This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new f600f70619 [ehancement](fe) Tune for stats framework (#17860)
f600f70619 is described below
commit f600f7061988e6ad6182243aef6beb13dc39846e
Author: AKIRA <[email protected]>
AuthorDate: Wed Mar 22 12:07:56 2023 +0900
[ehancement](fe) Tune for stats framework (#17860)
---
.../doris/nereids/stats/FilterEstimation.java | 49 +++++++++++++++++-----
.../apache/doris/nereids/stats/JoinEstimation.java | 23 ++++++----
.../doris/nereids/stats/StatsCalculator.java | 41 ++++++++++++++----
.../doris/nereids/stats/StatsErrorEstimator.java | 2 +-
.../doris/statistics/StatisticConstants.java | 2 +
.../org/apache/doris/statistics/Statistics.java | 11 +++++
.../doris/nereids/stats/FilterEstimationTest.java | 8 +---
tools/qerror.py | 28 ++++++++++---
8 files changed, 125 insertions(+), 39 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
index e2159f1040..1aaaeaa955 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.stats;
import org.apache.doris.nereids.stats.FilterEstimation.EstimationContext;
+import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
@@ -33,6 +34,7 @@ import
org.apache.doris.nereids.trees.expressions.NullSafeEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.statistics.Bucket;
@@ -49,6 +51,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
+import java.util.function.Predicate;
/**
* Calculate selectivity of expression that produces boolean value.
@@ -56,9 +59,21 @@ import java.util.Set;
*/
public class FilterEstimation extends ExpressionVisitor<Statistics,
EstimationContext> {
public static final double DEFAULT_INEQUALITY_COEFFICIENT = 0.5;
+ public static final double DEFAULT_IN_COEFFICIENT = 1.0 / 3.0;
+
+ public static final double DEFAULT_HAVING_COEFFICIENT = 0.01;
public static final double DEFAULT_EQUALITY_COMPARISON_SELECTIVITY = 0.1;
+ private Set<Slot> aggSlots;
+
+ public FilterEstimation() {
+ }
+
+ public FilterEstimation(Set<Slot> aggSlots) {
+ this.aggSlots = aggSlots;
+ }
+
/**
* This method will update the stats according to the selectivity.
*/
@@ -104,7 +119,6 @@ public class FilterEstimation extends
ExpressionVisitor<Statistics, EstimationCo
estimatedColStatsBuilder.setMaxValue(rightColStats.maxValue);
estimatedColStatsBuilder.setMaxExpr(rightColStats.maxExpr);
}
- orStats.addColumnStats(entry.getKey(),
estimatedColStatsBuilder.build());
}
return orStats;
}
@@ -127,6 +141,24 @@ public class FilterEstimation extends
ExpressionVisitor<Statistics, EstimationCo
}
ColumnStatistic statsForLeft = ExpressionEstimation.estimate(left,
context.statistics);
ColumnStatistic statsForRight = ExpressionEstimation.estimate(right,
context.statistics);
+ if (aggSlots != null) {
+ Predicate<TreeNode<Expression>> containsAggSlot = e -> {
+ if (e instanceof SlotReference) {
+ SlotReference slot = (SlotReference) e;
+ return aggSlots.contains(slot);
+ }
+ return false;
+ };
+ boolean leftAgg = left.anyMatch(containsAggSlot);
+ boolean rightAgg = right.anyMatch(containsAggSlot);
+ // It means this predicate appears in HAVING clause.
+ if (leftAgg || rightAgg) {
+ double rowCount = context.statistics.getRowCount();
+ double newRowCount = Math.max(rowCount *
DEFAULT_HAVING_COEFFICIENT,
+ Math.max(statsForLeft.ndv, statsForRight.ndv));
+ return context.statistics.withRowCount(newRowCount);
+ }
+ }
if (!(left instanceof Literal) && !(right instanceof Literal)) {
return calculateWhenBothColumn(cp, context, statsForLeft,
statsForRight);
} else {
@@ -167,14 +199,11 @@ public class FilterEstimation extends
ExpressionVisitor<Statistics, EstimationCo
double ndv = statsForLeft.ndv;
double val = statsForRight.maxValue;
if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
- if (statsForLeft == ColumnStatistic.UNKNOWN) {
- selectivity = DEFAULT_EQUALITY_COMPARISON_SELECTIVITY;
+
+ if (val > statsForLeft.maxValue || val < statsForLeft.minValue) {
+ selectivity = 0.0;
} else {
- if (val > statsForLeft.maxValue || val <
statsForLeft.minValue) {
- selectivity = 0.0;
- } else {
- selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
- }
+ selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
}
if (context.isNot) {
selectivity = 1 - selectivity;
@@ -249,8 +278,8 @@ public class FilterEstimation extends
ExpressionVisitor<Statistics, EstimationCo
boolean isNotIn = context != null && context.isNot;
Expression compareExpr = inPredicate.getCompareExpr();
ColumnStatistic compareExprStats =
ExpressionEstimation.estimate(compareExpr, context.statistics);
- if (compareExprStats.isUnKnown) {
- return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
+ if (compareExprStats.isUnKnown || compareExpr instanceof Function) {
+ return context.statistics.withSel(DEFAULT_IN_COEFFICIENT);
}
List<Expression> options = inPredicate.getOptions();
double maxOption = 0;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
index 0e8516e00c..d1427ef469 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
@@ -21,10 +21,10 @@ import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.algebra.Join;
-import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;
+import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
@@ -33,6 +33,7 @@ import java.util.stream.Collectors;
* TODO: Update other props in the ColumnStats properly.
*/
public class JoinEstimation {
+
private static Statistics estimateInnerJoin(Statistics crossJoinStats,
List<Expression> joinConditions) {
List<Pair<Expression, Double>> sortedJoinConditions =
joinConditions.stream()
.map(expression -> Pair.of(expression,
estimateJoinConditionSel(crossJoinStats, expression)))
@@ -51,7 +52,7 @@ public class JoinEstimation {
for (int i = 0; i < sortedJoinConditions.size(); i++) {
sel *= Math.pow(sortedJoinConditions.get(i).second, 1 /
Math.pow(2, i));
}
- return crossJoinStats.withSel(sel);
+ return crossJoinStats.updateRowCountOnly(crossJoinStats.getRowCount()
* sel);
}
private static double estimateJoinConditionSel(Statistics crossJoinStats,
Expression joinCond) {
@@ -69,13 +70,19 @@ public class JoinEstimation {
.putColumnStatistics(leftStats.columnStatistics())
.putColumnStatistics(rightStats.columnStatistics())
.build();
- List<Expression> joinConditions = join.getHashJoinConjuncts();
- Statistics innerJoinStats = estimateInnerJoin(crossJoinStats,
joinConditions);
- if (!join.getOtherJoinConjuncts().isEmpty()) {
- FilterEstimation filterEstimation = new FilterEstimation();
- innerJoinStats = filterEstimation.estimate(
- ExpressionUtils.and(join.getOtherJoinConjuncts()),
innerJoinStats);
+ Statistics innerJoinStats = null;
+ if (crossJoinStats.getRowCount() != 0) {
+ List<Expression> joinConditions = new
ArrayList<>(join.getHashJoinConjuncts());
+ joinConditions.addAll(join.getOtherJoinConjuncts());
+ innerJoinStats = estimateInnerJoin(crossJoinStats, joinConditions);
+ } else {
+ innerJoinStats = crossJoinStats;
}
+ // if (!join.getOtherJoinConjuncts().isEmpty()) {
+ // FilterEstimation filterEstimation = new FilterEstimation();
+ // innerJoinStats = filterEstimation.estimate(
+ // ExpressionUtils.and(join.getOtherJoinConjuncts()),
innerJoinStats);
+ // }
innerJoinStats.setWidth(leftStats.getWidth() + rightStats.getWidth());
innerJoinStats.setPenalty(0);
double rowCount;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java
index 8f10d6f4ea..3143e196f5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java
@@ -22,10 +22,12 @@ import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
@@ -94,6 +96,7 @@ import org.apache.doris.statistics.Statistics;
import org.apache.doris.statistics.StatisticsBuilder;
import com.google.common.collect.Maps;
+import org.apache.commons.collections.CollectionUtils;
import java.util.AbstractMap.SimpleEntry;
import java.util.HashMap;
@@ -394,9 +397,24 @@ public class StatsCalculator extends
DefaultPlanVisitor<Statistics, Void> {
}
private Statistics computeFilter(Filter filter) {
- FilterEstimation filterEstimation = new FilterEstimation();
Statistics stats = groupExpression.childStatistics(0);
- return filterEstimation.estimate(filter.getPredicate(), stats);
+ Plan plan = tryToFindChild(groupExpression);
+ if (plan != null) {
+ if (plan instanceof Aggregate) {
+ Aggregate agg = ((Aggregate<?>) plan);
+ List<NamedExpression> expressions = agg.getOutputExpressions();
+ Set<Slot> slots = expressions
+ .stream()
+ .filter(Alias.class::isInstance)
+ .filter(s -> ((Alias)
s).child().anyMatch(AggregateFunction.class::isInstance))
+
.map(NamedExpression::toSlot).collect(Collectors.toSet());
+ Expression predicate = filter.getPredicate();
+ if (predicate.anyMatch(s -> slots.contains(s))) {
+ return new
FilterEstimation(slots).estimate(filter.getPredicate(), stats);
+ }
+ }
+ }
+ return new FilterEstimation().estimate(filter.getPredicate(), stats);
}
// TODO: 1. Subtract the pruned partition
@@ -441,13 +459,8 @@ public class StatsCalculator extends
DefaultPlanVisitor<Statistics, Void> {
if (!groupByExpressions.isEmpty()) {
Map<Expression, ColumnStatistic> childSlotToColumnStats =
childStats.columnStatistics();
double inputRowCount = childStats.getRowCount();
- if (inputRowCount == 0) {
- //on empty relation, Agg output 1 tuple
- resultSetCount = 1;
- } else {
+ if (inputRowCount != 0) {
List<ColumnStatistic> groupByKeyStats =
groupByExpressions.stream()
- .flatMap(expr -> expr.getInputSlots().stream())
- .map(Slot::getExprId)
.filter(childSlotToColumnStats::containsKey)
.map(childSlotToColumnStats::get)
.filter(s -> !s.isUnKnown)
@@ -692,4 +705,16 @@ public class StatsCalculator extends
DefaultPlanVisitor<Statistics, Void> {
.setAvgSizeByte(newAverageRowSize);
return columnStatisticBuilder.build();
}
+
+ private Plan tryToFindChild(GroupExpression groupExpression) {
+ List<GroupExpression> groupExprs =
groupExpression.child(0).getLogicalExpressions();
+ if (CollectionUtils.isEmpty(groupExprs)) {
+ groupExprs = groupExpression.child(0).getPhysicalExpressions();
+ if (CollectionUtils.isEmpty(groupExprs)) {
+ return null;
+ }
+ }
+ return groupExprs.get(0).getPlan();
+ }
+
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java
index 6966fc97ea..54d309ac14 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java
@@ -52,7 +52,7 @@ public class StatsErrorEstimator {
}
/**
- * Map plan id to stats.
+ * Invoked by PhysicalPlanTranslator, put the translated plan node and
corresponding physical plan to estimator.
*/
public void updateLegacyPlanIdToPhysicalPlan(PlanNode planNode,
AbstractPlan physicalPlan) {
Statistics statistics = physicalPlan.getStats();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java
index df34c2f9d6..0345c7930e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java
@@ -62,4 +62,6 @@ public class StatisticConstants {
public static final int LOAD_TASK_LIMITS = 10;
+ public static final double DEFAULT_INNER_JOIN_FACTOR = 0.1;
+
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java
index b9cf6040e8..2aa0d1ad33 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java
@@ -83,14 +83,25 @@ public class Statistics {
return rowCount;
}
+ /*
+ * Return a stats with new rowCount and fix each column stats.
+ */
public Statistics withRowCount(double rowCount) {
+ if (Double.isNaN(rowCount)) {
+ return this;
+ }
Statistics statistics = new Statistics(rowCount, new
HashMap<>(expressionToColumnStats), width, penalty);
statistics.fix(rowCount, StatsMathUtil.nonZeroDivisor(this.rowCount));
return statistics;
}
+ public Statistics updateRowCountOnly(double rowCount) {
+ return new Statistics(rowCount, expressionToColumnStats);
+ }
+
public void fix(double newRowCount, double originRowCount) {
double sel = newRowCount / originRowCount;
+
for (Entry<Expression, ColumnStatistic> entry :
expressionToColumnStats.entrySet()) {
ColumnStatistic columnStatistic = entry.getValue();
ColumnStatisticBuilder columnStatisticBuilder = new
ColumnStatisticBuilder(columnStatistic);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
index 691cf53720..0ab0c7c2ff 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
@@ -115,9 +115,7 @@ class FilterEstimationTest {
Statistics stat = new Statistics(1000, slotToColumnStat);
FilterEstimation filterEstimation = new FilterEstimation();
Statistics expected = filterEstimation.estimate(in, stat);
- Assertions.assertEquals(
- FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT *
stat.getRowCount(),
- expected.getRowCount());
+ Assertions.assertTrue(Precision.equals(333.33, expected.getRowCount(),
0.01));
}
@Test
@@ -134,9 +132,7 @@ class FilterEstimationTest {
Statistics stat = new Statistics(1000, slotToColumnStat);
FilterEstimation filterEstimation = new FilterEstimation();
Statistics expected = filterEstimation.estimate(notIn, stat);
- Assertions.assertEquals(
- FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT *
stat.getRowCount(),
- expected.getRowCount());
+ Assertions.assertTrue(Precision.equals(333.33, expected.getRowCount(),
0.01));
}
/**
diff --git a/tools/qerror.py b/tools/qerror.py
index 428d9fdf90..70920b60a4 100644
--- a/tools/qerror.py
+++ b/tools/qerror.py
@@ -21,6 +21,7 @@ import subprocess
import requests
import json
+import time
mycli_cmd = "mysql -h127.0.0.1 -P9030 -uroot -Dtpch1G"
@@ -37,6 +38,8 @@ sql_file_prefix_for_trace = """
SET session_context='trace_id:{}';
"""
+q_err_list = []
+
def extract_number(string):
return int(''.join([c for c in string if c.isdigit()]))
@@ -73,6 +76,7 @@ def execute_sql(sql_file: str):
def get_q_error(trace_id):
+ time.sleep(1)
# 'YWRtaW46' is the base64 encoded result for 'admin:'
headers = {'Authorization': 'BASIC YWRtaW46'}
resp_wrapper = requests.get(trace_url.format(trace_id), headers=headers)
@@ -81,9 +85,13 @@ def get_q_error(trace_id):
resp_wrapper = requests.get(qerror_url.format(query_id), headers=headers)
resp_text = resp_wrapper.text
write_result(str(trace_id), resp_text)
+ print(trace_id)
+ print(resp_text)
+ qerr = json.loads(resp_text)["qError"]
+ q_err_list.append(float(qerr))
-def iterates_sqls(path: str) -> list:
+def iterates_sqls(path: str, if_write_results: bool) -> list:
cost_times = []
files = os.listdir(path)
files.sort(key=extract_number)
@@ -93,14 +101,22 @@ def iterates_sqls(path: str) -> list:
traced_sql_file = filepath + ".traced"
content = read_lines(filepath)
sql_num = extract_number(filename)
- write_results(traced_sql_file,
str(sql_file_prefix_for_trace.format(sql_num)), content)
- execute_sql(traced_sql_file)
- get_q_error(sql_num)
- os.remove(traced_sql_file)
+ print("sql num" + str(sql_num))
+ if if_write_results:
+ write_results(traced_sql_file,
str(sql_file_prefix_for_trace.format(sql_num)), content)
+ execute_sql(traced_sql_file)
+ get_q_error(sql_num)
+ os.remove(traced_sql_file)
+ else:
+ execute_sql(filepath)
return cost_times
if __name__ == '__main__':
execute_command("echo 'set global enable_nereids_planner=true' | mysql
-h127.0.0.1 -P9030")
execute_command("echo 'set global
enable_fallback_to_original_planner=false' | mysql -h127.0.0.1 -P9030")
- iterates_sqls(original_sql_dir)
+ print("Preparing")
+ iterates_sqls(original_sql_dir, False)
+ print("Started...")
+ iterates_sqls(original_sql_dir, True)
+ write_results(qerr_saved_file_path, "AVG\n", [sum(q_err_list) /
len(qerror_url)])
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]