This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 4cfe990095 [enhancement](Nereids) add test framework for otherjoin (#21887) 4cfe990095 is described below commit 4cfe990095d32b99bc0b784ac5a9a3d17b9b5e01 Author: 谢健 <jianx...@gmail.com> AuthorDate: Thu Jul 20 16:35:55 2023 +0800 [enhancement](Nereids) add test framework for otherjoin (#21887) --- .../jobs/joinorder/hypergraph/OtherJoinTest.java | 56 +++ .../doris/nereids/util/HyperGraphBuilder.java | 416 ++++++++++++++++++++- 2 files changed, 462 insertions(+), 10 deletions(-) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/OtherJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/OtherJoinTest.java new file mode 100644 index 0000000000..feeb971b15 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/jobs/joinorder/hypergraph/OtherJoinTest.java @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.jobs.joinorder.hypergraph; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.datasets.tpch.TPCHTestBase; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.util.HyperGraphBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Set; + +public class OtherJoinTest extends TPCHTestBase { + @Test + public void randomTest() { + HyperGraphBuilder hyperGraphBuilder = new HyperGraphBuilder(); + Plan plan = hyperGraphBuilder + .randomBuildPlanWith(10, 20); + Set<List<Integer>> res1 = hyperGraphBuilder.evaluate(plan); + CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(connectContext, plan); + hyperGraphBuilder.initStats(cascadesContext); + Plan optimizedPlan = PlanChecker.from(cascadesContext) + .dpHypOptimize() + .getBestPlanTree(); + + Set<List<Integer>> res2 = hyperGraphBuilder.evaluate(optimizedPlan); + if (!res1.equals(res2)) { + System.out.println(res1); + System.out.println(res2); + System.out.println(plan.treeString()); + System.out.println(optimizedPlan.treeString()); + } + Assertions.assertTrue(res1.equals(res2)); + + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java index 1fbdcf8954..a834f2bd92 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/HyperGraphBuilder.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.jobs.joinorder.JoinOrderJob; import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph; import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -30,31 +31,81 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin; +import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan; import org.apache.doris.statistics.ColumnStatistic; import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.BitSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; public class HyperGraphBuilder { private final List<Integer> rowCounts = new ArrayList<>(); + private final List<LogicalOlapScan> tables = new ArrayList<>(); private final HashMap<BitSet, LogicalPlan> plans = new HashMap<>(); private final HashMap<BitSet, List<Integer>> schemas = new HashMap<>(); + private final ImmutableList<JoinType> fullJoinTypes = ImmutableList.of( + JoinType.INNER_JOIN, + JoinType.LEFT_OUTER_JOIN, + JoinType.RIGHT_OUTER_JOIN, + JoinType.FULL_OUTER_JOIN + ); + + private final ImmutableList<JoinType> leftFullJoinTypes = ImmutableList.of( + JoinType.INNER_JOIN, + JoinType.LEFT_OUTER_JOIN, + JoinType.RIGHT_OUTER_JOIN, + JoinType.FULL_OUTER_JOIN, + JoinType.LEFT_SEMI_JOIN, + JoinType.LEFT_ANTI_JOIN, + JoinType.NULL_AWARE_LEFT_ANTI_JOIN + ); + + private final ImmutableList<JoinType> rightFullJoinTypes = ImmutableList.of( + JoinType.INNER_JOIN, + JoinType.LEFT_OUTER_JOIN, + JoinType.RIGHT_OUTER_JOIN, + JoinType.FULL_OUTER_JOIN, + JoinType.RIGHT_SEMI_JOIN, + JoinType.RIGHT_ANTI_JOIN + ); + public HyperGraph build() { assert plans.size() == 1 : "there are cross join"; Plan plan = plans.values().iterator().next(); return buildHyperGraph(plan); } + public Plan buildJoinPlan() { + assert plans.size() == 1 : "there are cross join"; + Plan plan = plans.values().iterator().next(); + return buildPlanWithJoinType(plan, new BitSet()); + } + + public Plan randomBuildPlanWith(int tableNum, int edgeNum) { + randomBuildInit(tableNum, edgeNum); + return this.buildJoinPlan(); + } + public HyperGraph randomBuildWith(int tableNum, int edgeNum) { + randomBuildInit(tableNum, edgeNum); + return this.build(); + } + + private void randomBuildInit(int tableNum, int edgeNum) { Preconditions.checkArgument(edgeNum >= tableNum - 1, String.format("We can't build a connected graph with %d tables %d edges", tableNum, edgeNum)); Preconditions.checkArgument(edgeNum <= tableNum * (tableNum - 1) / 2, @@ -93,7 +144,6 @@ public class HyperGraphBuilder { int right = keys[i].nextSetBit(0); this.addEdge(JoinType.INNER_JOIN, left, right); } - return this.build(); } public HyperGraphBuilder init(int... rowCounts) { @@ -101,7 +151,9 @@ public class HyperGraphBuilder { this.rowCounts.add(rowCounts[i]); BitSet bitSet = new BitSet(); bitSet.set(i); - plans.put(bitSet, PlanConstructor.newLogicalOlapScan(i, String.valueOf(i), 0)); + LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(i, String.valueOf(i), 0); + plans.put(bitSet, scan); + tables.add(scan); List<Integer> schema = new ArrayList<>(); schema.add(i); schemas.put(bitSet, schema); @@ -109,6 +161,17 @@ public class HyperGraphBuilder { return this; } + public void initStats(CascadesContext context) { + for (Group group : context.getMemo().getGroups()) { + GroupExpression groupExpression = group.getLogicalExpression(); + if (groupExpression.getPlan() instanceof LogicalOlapScan) { + Statistics stats = injectRowcount((LogicalOlapScan) groupExpression.getPlan()); + groupExpression.setStatDerived(true); + group.setStatistics(stats); + } + } + } + public HyperGraphBuilder addEdge(JoinType joinType, int node1, int node2) { Preconditions.checkArgument(node1 >= 0 && node1 < rowCounts.size(), String.format("%d must in [%d, %d)", node1, 0, rowCounts.size())); @@ -151,6 +214,41 @@ public class HyperGraphBuilder { return this; } + private Plan buildPlanWithJoinType(Plan plan, BitSet requireTable) { + if (!(plan instanceof LogicalJoin)) { + return plan; + } + LogicalJoin<? extends Plan, ? extends Plan> join = (LogicalJoin) plan; + BitSet leftSchema = findPlanSchema(join.left()); + BitSet rightSchema = findPlanSchema(join.right()); + JoinType joinType; + if (isSubset(requireTable, leftSchema)) { + int index = (int) (Math.random() * leftFullJoinTypes.size()); + joinType = leftFullJoinTypes.get(index); + } else if (isSubset(requireTable, rightSchema)) { + int index = (int) (Math.random() * rightFullJoinTypes.size()); + joinType = rightFullJoinTypes.get(index); + } else { + int index = (int) (Math.random() * fullJoinTypes.size()); + joinType = fullJoinTypes.get(index); + } + Set<Slot> requireSlots = join.getExpressions().stream() + .flatMap(expr -> expr.getInputSlots().stream()) + .collect(Collectors.toSet()); + for (int i = 0; i < tables.size(); i++) { + if (tables.get(i).getOutput().stream().anyMatch(slot -> requireSlots.contains(slot))) { + requireTable.set(i); + } + } + + Plan left = buildPlanWithJoinType(join.left(), requireTable); + Plan right = buildPlanWithJoinType(join.right(), requireTable); + Set<Slot> outputs = Stream.concat(left.getOutput().stream(), right.getOutput().stream()) + .collect(Collectors.toSet()); + assert outputs.containsAll(requireSlots); + return ((LogicalJoin) join.withChildren(left, right)).withJoinType(joinType); + } + private Optional<BitSet> findPlan(BitSet bitSet) { for (BitSet key : plans.keySet()) { if (isSubset(bitSet, key)) { @@ -160,6 +258,24 @@ public class HyperGraphBuilder { return Optional.empty(); } + private BitSet findPlanSchema(Plan plan) { + BitSet bitSet = new BitSet(); + if (plan instanceof LogicalOlapScan) { + for (int i = 0; i < tables.size(); i++) { + if (tables.get(i).equals(plan)) { + bitSet.set(i); + } + } + assert !bitSet.isEmpty(); + return bitSet; + } + + bitSet.or(findPlanSchema(((LogicalJoin) plan).left())); + bitSet.or(findPlanSchema(((LogicalJoin) plan).right())); + assert !bitSet.isEmpty(); + return bitSet; + } + private boolean isSubset(BitSet bitSet1, BitSet bitSet2) { BitSet bitSet = new BitSet(); bitSet.or(bitSet1); @@ -182,14 +298,7 @@ public class HyperGraphBuilder { private void injectRowcount(Group group) { if (!group.isInnerJoinGroup()) { LogicalOlapScan scanPlan = (LogicalOlapScan) group.getLogicalExpression().getPlan(); - HashMap<Expression, ColumnStatistic> slotIdToColumnStats = new HashMap<Expression, ColumnStatistic>(); - int count = rowCounts.get(Integer.parseInt(scanPlan.getTable().getName())); - for (Slot slot : scanPlan.getOutput()) { - slotIdToColumnStats.put(slot, - new ColumnStatistic(count, count, null, 0, 0, 0, 0, - 0, 0, null, null, true, null)); - } - Statistics stats = new Statistics(count, slotIdToColumnStats); + Statistics stats = injectRowcount(scanPlan); group.setStatistics(stats); return; } @@ -197,6 +306,17 @@ public class HyperGraphBuilder { injectRowcount(group.getLogicalExpression().child(1)); } + private Statistics injectRowcount(LogicalOlapScan scanPlan) { + HashMap<Expression, ColumnStatistic> slotIdToColumnStats = new HashMap<Expression, ColumnStatistic>(); + int count = rowCounts.get(Integer.parseInt(scanPlan.getTable().getName())); + for (Slot slot : scanPlan.getOutput()) { + slotIdToColumnStats.put(slot, + new ColumnStatistic(count, count, null, 1, 0, 0, 0, + count, 1, null, null, true, null)); + } + return new Statistics(count, slotIdToColumnStats); + } + private void constructJoin(int node1, int node2, BitSet key) { LogicalJoin join = (LogicalJoin) plans.get(key); Expression condition = makeCondition(node1, node2, key); @@ -241,4 +361,280 @@ public class HyperGraphBuilder { new EqualTo(plan.getOutput().get(leftIndex), plan.getOutput().get(rightIndex)); return hashConjunts; } + + public Set<List<Integer>> evaluate(Plan plan) { + JoinEvaluator evaluator = new JoinEvaluator(rowCounts); + Map<Slot, List<Integer>> res = evaluator.evaluate(plan); + int rowCount = 0; + if (res.size() > 0) { + rowCount = res.values().iterator().next().size(); + } + List<Slot> keySet = res.keySet().stream() + .sorted( + (slot1, slot2) -> + String.CASE_INSENSITIVE_ORDER.compare(slot1.toString(), slot2.toString())) + .collect(Collectors.toList()); + Set<List<Integer>> tuples = new HashSet<>(); + for (int i = 0; i < rowCount; i++) { + List<Integer> tuple = new ArrayList<>(); + for (Slot key : keySet) { + tuple.add(res.get(key).get(i)); + } + tuples.add(tuple); + } + return tuples; + } + + class JoinEvaluator { + List<Integer> rowCounts; + + JoinEvaluator(List<Integer> rowCounts) { + this.rowCounts = rowCounts; + } + + Map<Slot, List<Integer>> evaluate(Plan plan) { + if (plan instanceof LogicalOlapScan || plan instanceof PhysicalOlapScan) { + return evaluateScan(plan); + } + if (plan instanceof LogicalJoin || plan instanceof AbstractPhysicalJoin) { + return evaluateJoin(plan); + } + assert plan.children().size() == 1; + return evaluate(plan.child(0)); + } + + public Map<Slot, List<Integer>> evaluateScan(Plan plan) { + String name; + if (plan instanceof LogicalOlapScan) { + name = ((LogicalOlapScan) plan).getTable().getName(); + } else { + Preconditions.checkArgument(plan instanceof PhysicalOlapScan); + name = ((PhysicalOlapScan) plan).getTable().getName(); + } + int rowCount = rowCounts.get(Integer.parseInt(name)); + Map<Slot, List<Integer>> rows = new HashMap<>(); + for (Slot slot : plan.getOutput()) { + rows.put(slot, new ArrayList<>()); + for (int i = 0; i < rowCount; i++) { + rows.get(slot).add(i); + } + } + return rows; + } + + public Map<Slot, List<Integer>> evaluateJoin(Plan plan) { + Map<Slot, List<Integer>> left; + Map<Slot, List<Integer>> right; + List<? extends Expression> expressions = plan.getExpressions(); + JoinType joinType; + if (plan instanceof LogicalJoin) { + left = this.evaluate(((LogicalJoin<?, ?>) plan).left()); + right = this.evaluate(((LogicalJoin<?, ?>) plan).right()); + joinType = ((LogicalJoin<?, ?>) plan).getJoinType(); + } else { + Preconditions.checkArgument(plan instanceof AbstractPhysicalJoin); + left = this.evaluate(((AbstractPhysicalJoin<?, ?>) plan).left()); + right = this.evaluate(((AbstractPhysicalJoin<?, ?>) plan).right()); + joinType = ((AbstractPhysicalJoin<?, ?>) plan).getJoinType(); + } + + List<Pair<Integer, Integer>> matchPair = new ArrayList<>(); + for (int i = 0; i < getTableRC(left); i++) { + for (int j = 0; j < getTableRC(right); j++) { + int leftIndex = i; + int rightIndex = j; + Boolean matched = true; + for (Expression expr : expressions) { + Boolean res = evaluateExpr(joinType, expr, left, leftIndex, right, rightIndex); + if (res == null) { + matched = null; + } else if (res == false) { + matched = false; + break; + } + } + if (matched == null) { + // NAAJ return nothing when right has null + for (int i1 = 0; i1 < getTableRC(left); i1++) { + for (int j1 = 0; j1 < getTableRC(right); j1++) { + matchPair.add(Pair.of(i1, j1)); + } + } + return calJoin(joinType, left, right, matchPair); + } + if (matched) { + matchPair.add(Pair.of(i, j)); + } + } + } + return calJoin(joinType, left, right, matchPair); + + } + + Map<Slot, List<Integer>> calJoin(JoinType joinType, Map<Slot, List<Integer>> left, + Map<Slot, List<Integer>> right, List<Pair<Integer, Integer>> matchPair) { + switch (joinType) { + case INNER_JOIN: + return calIJ(left, right, matchPair); + case LEFT_OUTER_JOIN: + return calLOJ(left, right, matchPair); + case RIGHT_OUTER_JOIN: + return calLOJ(right, left, + matchPair.stream().map(p -> Pair.of(p.second, p.first)).collect(Collectors.toList())); + case FULL_OUTER_JOIN: + return calFOJ(left, right, matchPair); + case LEFT_SEMI_JOIN: + return calLSJ(left, right, matchPair); + case RIGHT_SEMI_JOIN: + return calLSJ(right, left, + matchPair.stream().map(p -> Pair.of(p.second, p.first)).collect(Collectors.toList())); + case LEFT_ANTI_JOIN: + return calLAJ(left, right, matchPair); + case RIGHT_ANTI_JOIN: + return calLAJ(right, left, + matchPair.stream().map(p -> Pair.of(p.second, p.first)).collect(Collectors.toList())); + case NULL_AWARE_LEFT_ANTI_JOIN: + return calLNAAJ(left, right, matchPair); + default: + assert false; + } + assert false; + return new HashMap<>(); + } + + Map<Slot, List<Integer>> calIJ(Map<Slot, List<Integer>> left, + Map<Slot, List<Integer>> right, List<Pair<Integer, Integer>> matchPair) { + Map<Slot, List<Integer>> outputs = new HashMap<>(); + for (Slot slot : left.keySet()) { + outputs.put(slot, new ArrayList<>()); + } + for (Slot slot : right.keySet()) { + outputs.put(slot, new ArrayList<>()); + } + for (Pair<Integer, Integer> p : matchPair) { + for (Slot slot : left.keySet()) { + outputs.get(slot).add(left.get(slot).get(p.first)); + } + for (Slot slot : right.keySet()) { + outputs.get(slot).add(right.get(slot).get(p.second)); + } + } + return outputs; + } + + Map<Slot, List<Integer>> calFOJ(Map<Slot, List<Integer>> left, + Map<Slot, List<Integer>> right, List<Pair<Integer, Integer>> matchPair) { + Map<Slot, List<Integer>> outputs = calIJ(left, right, matchPair); + Set<Integer> leftIndices = matchPair.stream().map(p -> p.first).collect(Collectors.toSet()); + Set<Integer> rightIndices = matchPair.stream().map(p -> p.second).collect(Collectors.toSet()); + + for (int i = 0; i < getTableRC(left); i++) { + if (leftIndices.contains(i)) { + continue; + } + for (Slot slot : left.keySet()) { + outputs.get(slot).add(left.get(slot).get(i)); + } + for (Slot slot : right.keySet()) { + outputs.get(slot).add(null); + } + } + + for (int i = 0; i < getTableRC(right); i++) { + if (rightIndices.contains(i)) { + continue; + } + for (Slot slot : left.keySet()) { + outputs.get(slot).add(null); + } + for (Slot slot : right.keySet()) { + outputs.get(slot).add(right.get(slot).get(i)); + } + } + + return outputs; + } + + Map<Slot, List<Integer>> calLOJ(Map<Slot, List<Integer>> left, + Map<Slot, List<Integer>> right, List<Pair<Integer, Integer>> matchPair) { + Map<Slot, List<Integer>> outputs = calIJ(left, right, matchPair); + Set<Integer> leftIndices = matchPair.stream().map(p -> p.first).collect(Collectors.toSet()); + for (int i = 0; i < getTableRC(left); i++) { + if (leftIndices.contains(i)) { + continue; + } + for (Slot slot : left.keySet()) { + outputs.get(slot).add(left.get(slot).get(i)); + } + for (Slot slot : right.keySet()) { + outputs.get(slot).add(null); + } + } + return outputs; + } + + Map<Slot, List<Integer>> calLSJ(Map<Slot, List<Integer>> left, + Map<Slot, List<Integer>> right, List<Pair<Integer, Integer>> matchPair) { + Map<Slot, List<Integer>> outputs = new HashMap<>(); + for (Slot slot : left.keySet()) { + outputs.put(slot, new ArrayList<>()); + } + for (Pair<Integer, Integer> p : matchPair) { + for (Slot slot : left.keySet()) { + outputs.get(slot).add(left.get(slot).get(p.first)); + } + } + return outputs; + } + + Map<Slot, List<Integer>> calLAJ(Map<Slot, List<Integer>> left, + Map<Slot, List<Integer>> right, List<Pair<Integer, Integer>> matchPair) { + Map<Slot, List<Integer>> outputs = new HashMap<>(); + for (Slot slot : left.keySet()) { + outputs.put(slot, new ArrayList<>()); + } + Set<Integer> leftIndices = matchPair.stream().map(p -> p.first).collect(Collectors.toSet()); + for (int i = 0; i < getTableRC(left); i++) { + if (leftIndices.contains(i)) { + continue; + } + for (Slot slot : left.keySet()) { + outputs.get(slot).add(left.get(slot).get(i)); + } + } + return outputs; + } + + Map<Slot, List<Integer>> calLNAAJ(Map<Slot, List<Integer>> left, + Map<Slot, List<Integer>> right, List<Pair<Integer, Integer>> matchPair) { + return calLAJ(left, right, matchPair); + } + + Boolean evaluateExpr(JoinType joinType, Expression expr, Map<Slot, List<Integer>> left, int leftIndex, + Map<Slot, List<Integer>> right, int rightIndex) { + List<Slot> slots = Lists.newArrayList(expr.getInputSlots()); + Preconditions.checkArgument(slots.size() == 2); + Integer lv; + Integer rv; + if (left.containsKey(slots.get(0))) { + lv = left.get(slots.get(0)).get(leftIndex); + rv = right.get(slots.get(1)).get(rightIndex); + } else { + lv = right.get(slots.get(0)).get(rightIndex); + rv = left.get(slots.get(1)).get(leftIndex); + } + Boolean res = (lv == rv); + if (joinType.isNullAwareLeftAntiJoin()) { + res |= (lv == null); + } + if (joinType.isNullAwareLeftAntiJoin() && rv == null) { + res = null; + } + return res; + } + } + + private int getTableRC(Map<Slot, List<Integer>> m) { + return m.entrySet().iterator().next().getValue().size(); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org