This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new fa3bdbce966 [opt](nereids) enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE
(#43856)
fa3bdbce966 is described below
commit fa3bdbce966512da6986046df8558c1d04e93f61
Author: minghong <[email protected]>
AuthorDate: Thu Nov 28 11:19:32 2024 +0800
[opt](nereids) enhance PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE (#43856)
### What problem does this PR solve?
PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE had the following restrictions:
- it did not support count(*);
- it did not support joins that have other (non-hash) join conjuncts;
- it did not support a project between the aggregate and the join when that
project contains non-slot expressions.
This PR removes the above restrictions for the agg-project-join pattern.
---
.../rewrite/PushDownAggThroughJoinOneSide.java | 123 +++++++++++++++------
.../rewrite/PushDownMinMaxSumThroughJoinTest.java | 16 ++-
.../push_down_count_through_join_one_side.out | 22 ++++
.../push_down_count_through_join_one_side.groovy | 95 ++++++++++++++++
4 files changed, 212 insertions(+), 44 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
index f32bf8ea91b..c5d3d0fb49a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOneSide.java
@@ -36,6 +36,7 @@ import
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
+import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
@@ -74,8 +75,8 @@ public class PushDownAggThroughJoinOneSide implements
RewriteRuleFactory {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f
instanceof Max || f instanceof Sum
- || (f instanceof Count &&
!((Count) f).isCountStar())) && !f.isDistinct()
- && f.child(0) instanceof Slot);
+ || f instanceof Count &&
!f.isDistinct()
+ && (f.children().isEmpty() ||
f.child(0) instanceof Slot)));
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules =
ctx.cascadesContext.getConnectContext()
@@ -88,15 +89,16 @@ public class PushDownAggThroughJoinOneSide implements
RewriteRuleFactory {
})
.toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE),
logicalAggregate(logicalProject(innerLogicalJoin()))
- .when(agg -> agg.child().isAllSlots())
- .when(agg ->
agg.child().child().getOtherJoinConjuncts().isEmpty())
- .whenNot(agg ->
agg.child().children().stream().anyMatch(p -> p instanceof LogicalAggregate))
+ // .when(agg -> agg.child().isAllSlots())
+ // .when(agg ->
agg.child().child().getOtherJoinConjuncts().isEmpty())
+ .whenNot(agg -> agg.child()
+ .child(0).children().stream().anyMatch(p -> p
instanceof LogicalAggregate))
.when(agg -> {
Set<AggregateFunction> funcs =
agg.getAggregateFunctions();
return !funcs.isEmpty() && funcs.stream()
.allMatch(f -> (f instanceof Min || f
instanceof Max || f instanceof Sum
- || (f instanceof Count &&
(!((Count) f).isCountStar()))) && !f.isDistinct()
- && f.child(0) instanceof Slot);
+ || f instanceof Count) &&
!f.isDistinct()
+ && (f.children().isEmpty() ||
f.child(0) instanceof Slot));
})
.thenApply(ctx -> {
Set<Integer> enableNereidsRules =
ctx.cascadesContext.getConnectContext()
@@ -118,23 +120,6 @@ public class PushDownAggThroughJoinOneSide implements
RewriteRuleFactory {
LogicalJoin<Plan, Plan> join, List<NamedExpression> projects) {
List<Slot> leftOutput = join.left().getOutput();
List<Slot> rightOutput = join.right().getOutput();
-
- List<AggregateFunction> leftFuncs = new ArrayList<>();
- List<AggregateFunction> rightFuncs = new ArrayList<>();
- for (AggregateFunction func : agg.getAggregateFunctions()) {
- Slot slot = (Slot) func.child(0);
- if (leftOutput.contains(slot)) {
- leftFuncs.add(func);
- } else if (rightOutput.contains(slot)) {
- rightFuncs.add(func);
- } else {
- throw new IllegalStateException("Slot " + slot + " not found
in join output");
- }
- }
- if (leftFuncs.isEmpty() && rightFuncs.isEmpty()) {
- return null;
- }
-
Set<Slot> leftGroupBy = new HashSet<>();
Set<Slot> rightGroupBy = new HashSet<>();
for (Expression e : agg.getGroupByExpressions()) {
@@ -144,18 +129,71 @@ public class PushDownAggThroughJoinOneSide implements
RewriteRuleFactory {
} else if (rightOutput.contains(slot)) {
rightGroupBy.add(slot);
} else {
- return null;
+ if (projects.isEmpty()) {
+ // TODO: select ... from ... group by A , B, 1.2; 1.2 is
constant
+ return null;
+ } else {
+ for (NamedExpression proj : projects) {
+ if (proj instanceof Alias &&
proj.toSlot().equals(slot)) {
+ Set<Slot> inputForAliasSet = proj.getInputSlots();
+ for (Slot aliasInputSlot : inputForAliasSet) {
+ if (leftOutput.contains(aliasInputSlot)) {
+ leftGroupBy.add(aliasInputSlot);
+ } else if
(rightOutput.contains(aliasInputSlot)) {
+ rightGroupBy.add(aliasInputSlot);
+ } else {
+ return null;
+ }
+ }
+ break;
+ }
+ }
+ }
}
}
- join.getHashJoinConjuncts().forEach(e ->
e.getInputSlots().forEach(slot -> {
- if (leftOutput.contains(slot)) {
- leftGroupBy.add(slot);
- } else if (rightOutput.contains(slot)) {
- rightGroupBy.add(slot);
+
+ List<AggregateFunction> leftFuncs = new ArrayList<>();
+ List<AggregateFunction> rightFuncs = new ArrayList<>();
+ Count countStar = null;
+ Count rewrittenCountStar = null;
+ for (AggregateFunction func : agg.getAggregateFunctions()) {
+ if (func instanceof Count && ((Count) func).isCountStar()) {
+ countStar = (Count) func;
+ } else {
+ Slot slot = (Slot) func.child(0);
+ if (leftOutput.contains(slot)) {
+ leftFuncs.add(func);
+ } else if (rightOutput.contains(slot)) {
+ rightFuncs.add(func);
+ } else {
+ throw new IllegalStateException("Slot " + slot + " not
found in join output");
+ }
+ }
+ }
+ // rewrite count(*) to count(A), where A is slot from left/right group
by key
+ if (countStar != null) {
+ if (!leftGroupBy.isEmpty()) {
+ rewrittenCountStar = (Count)
countStar.withChildren(leftGroupBy.iterator().next());
+ leftFuncs.add(rewrittenCountStar);
+ } else if (!rightGroupBy.isEmpty()) {
+ rewrittenCountStar = (Count)
countStar.withChildren(rightGroupBy.iterator().next());
+ rightFuncs.add(rewrittenCountStar);
} else {
- throw new IllegalStateException("Slot " + slot + " not found
in join output");
+ return null;
+ }
+ }
+ for (Expression condition : join.getHashJoinConjuncts()) {
+ for (Slot joinConditionSlot : condition.getInputSlots()) {
+ if (leftOutput.contains(joinConditionSlot)) {
+ leftGroupBy.add(joinConditionSlot);
+ } else if (rightOutput.contains(joinConditionSlot)) {
+ rightGroupBy.add(joinConditionSlot);
+ } else {
+ // apply failed
+ return null;
+ }
}
- }));
+ }
Plan left = join.left();
Plan right = join.right();
@@ -196,6 +234,10 @@ public class PushDownAggThroughJoinOneSide implements
RewriteRuleFactory {
for (NamedExpression ne : agg.getOutputExpressions()) {
if (ne instanceof Alias && ((Alias) ne).child() instanceof
AggregateFunction) {
AggregateFunction func = (AggregateFunction) ((Alias)
ne).child();
+ if (func instanceof Count && ((Count) func).isCountStar()) {
+ // countStar is already rewritten as count(left_slot) or
count(right_slot)
+ func = rewrittenCountStar;
+ }
Slot slot = (Slot) func.child(0);
if (leftSlotToOutput.containsKey(slot)) {
Expression newFunc = replaceAggFunc(func,
leftSlotToOutput.get(slot).toSlot());
@@ -210,9 +252,20 @@ public class PushDownAggThroughJoinOneSide implements
RewriteRuleFactory {
newOutputExprs.add(ne);
}
}
-
- // TODO: column prune project
- return agg.withAggOutputChild(newOutputExprs, newJoin);
+ Plan newAggChild = newJoin;
+ if (agg.child() instanceof LogicalProject) {
+ LogicalProject project = (LogicalProject) agg.child();
+ List<NamedExpression> newProjections = Lists.newArrayList();
+ newProjections.addAll(project.getProjects());
+ Set<NamedExpression> leftDifference = new
HashSet<NamedExpression>(left.getOutput());
+ leftDifference.removeAll(project.getProjects());
+ newProjections.addAll(leftDifference);
+ Set<NamedExpression> rightDifference = new
HashSet<NamedExpression>(right.getOutput());
+ rightDifference.removeAll(project.getProjects());
+ newProjections.addAll(rightDifference);
+ newAggChild = ((LogicalProject)
agg.child()).withProjectsAndChild(newProjections, newJoin);
+ }
+ return agg.withAggOutputChild(newOutputExprs, newAggChild);
}
private static Expression replaceAggFunc(AggregateFunction func, Slot
inputSlot) {
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
index 58ab7fbe9e9..cffe91045d0 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownMinMaxSumThroughJoinTest.java
@@ -323,11 +323,11 @@ class PushDownMinMaxSumThroughJoinTest implements
MemoPatternMatchSupported {
.applyTopDown(new PushDownAggThroughJoinOneSide())
.printlnTree()
.matches(
- logicalAggregate(
- logicalJoin(
- logicalOlapScan(),
+ logicalJoin(
+ logicalAggregate(
logicalOlapScan()
- )
+ ),
+ logicalOlapScan()
)
);
}
@@ -346,11 +346,9 @@ class PushDownMinMaxSumThroughJoinTest implements
MemoPatternMatchSupported {
PlanChecker.from(MemoTestUtils.createConnectContext(), plan)
.applyTopDown(new PushDownAggThroughJoinOneSide())
.matches(
- logicalAggregate(
- logicalJoin(
- logicalOlapScan(),
- logicalOlapScan()
- )
+ logicalJoin(
+ logicalAggregate(logicalOlapScan()),
+ logicalAggregate(logicalOlapScan())
)
);
}
diff --git
a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
index da69919becd..8267eb3e38f 100644
---
a/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
+++
b/regression-test/data/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.out
@@ -1034,3 +1034,25 @@ Used:
UnUsed: use_push_down_agg_through_join_one_side
SyntaxError:
+-- !shape --
+PhysicalResultSink
+--PhysicalTopN[MERGE_SORT]
+----PhysicalTopN[LOCAL_SORT]
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------hashJoin[INNER_JOIN]
hashCondition=((dwd_tracking_sensor_init_tmp_ymd.dt =
dw_user_b2c_tracking_info_tmp_ymd.dt) and
(dwd_tracking_sensor_init_tmp_ymd.guid =
dw_user_b2c_tracking_info_tmp_ymd.guid))
otherCondition=((dwd_tracking_sensor_init_tmp_ymd.dt >=
substring(first_visit_time, 1, 10)))
+------------hashAgg[GLOBAL]
+--------------hashAgg[LOCAL]
+----------------filter((dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19')
and (dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'))
+------------------PhysicalOlapScan[dwd_tracking_sensor_init_tmp_ymd]
+------------filter((dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'))
+--------------PhysicalOlapScan[dw_user_b2c_tracking_info_tmp_ymd]
+
+Hint log:
+Used: use_PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE
+UnUsed:
+SyntaxError:
+
+-- !agg_pushed --
+2 是 2024-08-19
+
diff --git
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
index 02e06710296..e551fa04c91 100644
---
a/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
+++
b/regression-test/suites/nereids_rules_p0/eager_aggregate/push_down_count_through_join_one_side.groovy
@@ -426,4 +426,99 @@ suite("push_down_count_through_join_one_side") {
qt_with_hint_groupby_pushdown_nested_queries """
explain shape plan select /*+
USE_CBO_RULE(push_down_agg_through_join_one_side) */ count(*) from (select *
from count_t_one_side where score > 20) t1 join (select * from count_t_one_side
where id < 100) t2 on t1.id = t2.id group by t1.name;
"""
+
+ sql """
+ drop table if exists dw_user_b2c_tracking_info_tmp_ymd;
+ create table dw_user_b2c_tracking_info_tmp_ymd (
+ guid int,
+ dt varchar,
+ first_visit_time varchar
+ )Engine=Olap
+ DUPLICATE KEY(guid)
+ distributed by hash(dt) buckets 3
+ properties('replication_num' = '1');
+
+ insert into dw_user_b2c_tracking_info_tmp_ymd values (1, '2024-08-19',
'2024-08-19');
+
+ drop table if exists dwd_tracking_sensor_init_tmp_ymd;
+ create table dwd_tracking_sensor_init_tmp_ymd (
+ guid int,
+ dt varchar,
+ tracking_type varchar
+ )Engine=Olap
+ DUPLICATE KEY(guid)
+ distributed by hash(dt) buckets 3
+ properties('replication_num' = '1');
+
+ insert into dwd_tracking_sensor_init_tmp_ymd values(1, '2024-08-19',
'click'), (1, '2024-08-19', 'click');
+ """
+ sql """
+ set ENABLE_NEREIDS_RULES = "PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE";
+ set disable_join_reorder=true;
+ """
+
+ qt_shape """
+ explain shape plan
+ SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
+ Count(*) AS accee593,
+ CASE
+ WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
+ Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+ 10) THEN
+ '是'
+ WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
+ Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+ 10) THEN
+ '否'
+ ELSE '-1'
+ end AS a1302fb2,
+ dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
+ FROM dwd_tracking_sensor_init_tmp_ymd
+ LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
+ ON dwd_tracking_sensor_init_tmp_ymd.guid =
+ dw_user_b2c_tracking_info_tmp_ymd.guid
+ AND dwd_tracking_sensor_init_tmp_ymd.dt =
+ dw_user_b2c_tracking_info_tmp_ymd.dt
+ WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
+ AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
+ AND dwd_tracking_sensor_init_tmp_ymd.dt >=
+ Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10)
+ AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
+ GROUP BY 2,
+ 3
+ ORDER BY 3 ASC
+ LIMIT 10000;
+ """
+
+ qt_agg_pushed """
+ SELECT /*+use_cbo_rule(PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE)*/
+ Count(*) AS accee593,
+ CASE
+ WHEN dwd_tracking_sensor_init_tmp_ymd.dt =
+ Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+ 10) THEN
+ '是'
+ WHEN dwd_tracking_sensor_init_tmp_ymd.dt >
+ Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
+ 10) THEN
+ '否'
+ ELSE '-1'
+ end AS a1302fb2,
+ dwd_tracking_sensor_init_tmp_ymd.dt AS ad466123
+ FROM dwd_tracking_sensor_init_tmp_ymd
+ LEFT JOIN dw_user_b2c_tracking_info_tmp_ymd
+ ON dwd_tracking_sensor_init_tmp_ymd.guid =
+ dw_user_b2c_tracking_info_tmp_ymd.guid
+ AND dwd_tracking_sensor_init_tmp_ymd.dt =
+ dw_user_b2c_tracking_info_tmp_ymd.dt
+ WHERE dwd_tracking_sensor_init_tmp_ymd.dt = '2024-08-19'
+ AND dw_user_b2c_tracking_info_tmp_ymd.dt = '2024-08-19'
+ AND dwd_tracking_sensor_init_tmp_ymd.dt >=
+ Substring(dw_user_b2c_tracking_info_tmp_ymd.first_visit_time, 1,
10)
+ AND dwd_tracking_sensor_init_tmp_ymd.tracking_type = 'click'
+ GROUP BY 2,
+ 3
+ ORDER BY 3 ASC
+ LIMIT 10000;
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]