This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new f600f70619 [ehancement](fe) Tune for stats framework (#17860)
f600f70619 is described below

commit f600f7061988e6ad6182243aef6beb13dc39846e
Author: AKIRA <[email protected]>
AuthorDate: Wed Mar 22 12:07:56 2023 +0900

    [ehancement](fe) Tune for stats framework (#17860)
---
 .../doris/nereids/stats/FilterEstimation.java      | 49 +++++++++++++++++-----
 .../apache/doris/nereids/stats/JoinEstimation.java | 23 ++++++----
 .../doris/nereids/stats/StatsCalculator.java       | 41 ++++++++++++++----
 .../doris/nereids/stats/StatsErrorEstimator.java   |  2 +-
 .../doris/statistics/StatisticConstants.java       |  2 +
 .../org/apache/doris/statistics/Statistics.java    | 11 +++++
 .../doris/nereids/stats/FilterEstimationTest.java  |  8 +---
 tools/qerror.py                                    | 28 ++++++++++---
 8 files changed, 125 insertions(+), 39 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
index e2159f1040..1aaaeaa955 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.stats;
 
 import org.apache.doris.nereids.stats.FilterEstimation.EstimationContext;
+import org.apache.doris.nereids.trees.TreeNode;
 import org.apache.doris.nereids.trees.expressions.And;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
@@ -33,6 +34,7 @@ import 
org.apache.doris.nereids.trees.expressions.NullSafeEqual;
 import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.statistics.Bucket;
@@ -49,6 +51,7 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.function.Predicate;
 
 /**
  * Calculate selectivity of expression that produces boolean value.
@@ -56,9 +59,21 @@ import java.util.Set;
  */
 public class FilterEstimation extends ExpressionVisitor<Statistics, 
EstimationContext> {
     public static final double DEFAULT_INEQUALITY_COEFFICIENT = 0.5;
+    public static final double DEFAULT_IN_COEFFICIENT = 1.0 / 3.0;
+
+    public static final double DEFAULT_HAVING_COEFFICIENT = 0.01;
 
     public static final double DEFAULT_EQUALITY_COMPARISON_SELECTIVITY = 0.1;
 
+    private Set<Slot> aggSlots;
+
+    public FilterEstimation() {
+    }
+
+    public FilterEstimation(Set<Slot> aggSlots) {
+        this.aggSlots = aggSlots;
+    }
+
     /**
      * This method will update the stats according to the selectivity.
      */
@@ -104,7 +119,6 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
                     
estimatedColStatsBuilder.setMaxValue(rightColStats.maxValue);
                     estimatedColStatsBuilder.setMaxExpr(rightColStats.maxExpr);
                 }
-                orStats.addColumnStats(entry.getKey(), 
estimatedColStatsBuilder.build());
             }
             return orStats;
         }
@@ -127,6 +141,24 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
         }
         ColumnStatistic statsForLeft = ExpressionEstimation.estimate(left, 
context.statistics);
         ColumnStatistic statsForRight = ExpressionEstimation.estimate(right, 
context.statistics);
+        if (aggSlots != null) {
+            Predicate<TreeNode<Expression>> containsAggSlot = e -> {
+                if (e instanceof SlotReference) {
+                    SlotReference slot = (SlotReference) e;
+                    return aggSlots.contains(slot);
+                }
+                return false;
+            };
+            boolean leftAgg = left.anyMatch(containsAggSlot);
+            boolean rightAgg = right.anyMatch(containsAggSlot);
+            // It means this predicate appears in HAVING clause.
+            if (leftAgg || rightAgg) {
+                double rowCount = context.statistics.getRowCount();
+                double newRowCount = Math.max(rowCount * DEFAULT_HAVING_COEFFICIENT,
+                        Math.max(statsForLeft.ndv, statsForRight.ndv));
+                return context.statistics.withRowCount(newRowCount);
+            }
+        }
         if (!(left instanceof Literal) && !(right instanceof Literal)) {
             return calculateWhenBothColumn(cp, context, statsForLeft, 
statsForRight);
         } else {
@@ -167,14 +199,11 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
         double ndv = statsForLeft.ndv;
         double val = statsForRight.maxValue;
         if (cp instanceof EqualTo || cp instanceof NullSafeEqual) {
-            if (statsForLeft == ColumnStatistic.UNKNOWN) {
-                selectivity = DEFAULT_EQUALITY_COMPARISON_SELECTIVITY;
+
+            if (val > statsForLeft.maxValue || val < statsForLeft.minValue) {
+                selectivity = 0.0;
             } else {
-                if (val > statsForLeft.maxValue || val < 
statsForLeft.minValue) {
-                    selectivity = 0.0;
-                } else {
-                    selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
-                }
+                selectivity = StatsMathUtil.minNonNaN(1.0, 1.0 / ndv);
             }
             if (context.isNot) {
                 selectivity = 1 - selectivity;
@@ -249,8 +278,8 @@ public class FilterEstimation extends 
ExpressionVisitor<Statistics, EstimationCo
         boolean isNotIn = context != null && context.isNot;
         Expression compareExpr = inPredicate.getCompareExpr();
         ColumnStatistic compareExprStats = 
ExpressionEstimation.estimate(compareExpr, context.statistics);
-        if (compareExprStats.isUnKnown) {
-            return context.statistics.withSel(DEFAULT_INEQUALITY_COEFFICIENT);
+        if (compareExprStats.isUnKnown || compareExpr instanceof Function) {
+            return context.statistics.withSel(DEFAULT_IN_COEFFICIENT);
         }
         List<Expression> options = inPredicate.getOptions();
         double maxOption = 0;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
index 0e8516e00c..d1427ef469 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/JoinEstimation.java
@@ -21,10 +21,10 @@ import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.plans.JoinType;
 import org.apache.doris.nereids.trees.plans.algebra.Join;
-import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.statistics.Statistics;
 import org.apache.doris.statistics.StatisticsBuilder;
 
+import java.util.ArrayList;
 import java.util.List;
 import java.util.stream.Collectors;
 
@@ -33,6 +33,7 @@ import java.util.stream.Collectors;
  * TODO: Update other props in the ColumnStats properly.
  */
 public class JoinEstimation {
+
     private static Statistics estimateInnerJoin(Statistics crossJoinStats, 
List<Expression> joinConditions) {
         List<Pair<Expression, Double>> sortedJoinConditions = 
joinConditions.stream()
                 .map(expression -> Pair.of(expression, 
estimateJoinConditionSel(crossJoinStats, expression)))
@@ -51,7 +52,7 @@ public class JoinEstimation {
         for (int i = 0; i < sortedJoinConditions.size(); i++) {
             sel *= Math.pow(sortedJoinConditions.get(i).second, 1 / 
Math.pow(2, i));
         }
-        return crossJoinStats.withSel(sel);
+        return crossJoinStats.updateRowCountOnly(crossJoinStats.getRowCount() * sel);
     }
 
     private static double estimateJoinConditionSel(Statistics crossJoinStats, 
Expression joinCond) {
@@ -69,13 +70,19 @@ public class JoinEstimation {
                 .putColumnStatistics(leftStats.columnStatistics())
                 .putColumnStatistics(rightStats.columnStatistics())
                 .build();
-        List<Expression> joinConditions = join.getHashJoinConjuncts();
-        Statistics innerJoinStats = estimateInnerJoin(crossJoinStats, 
joinConditions);
-        if (!join.getOtherJoinConjuncts().isEmpty()) {
-            FilterEstimation filterEstimation = new FilterEstimation();
-            innerJoinStats = filterEstimation.estimate(
-                    ExpressionUtils.and(join.getOtherJoinConjuncts()), 
innerJoinStats);
+        Statistics innerJoinStats = null;
+        if (crossJoinStats.getRowCount() != 0) {
+            List<Expression> joinConditions = new ArrayList<>(join.getHashJoinConjuncts());
+            joinConditions.addAll(join.getOtherJoinConjuncts());
+            innerJoinStats = estimateInnerJoin(crossJoinStats, joinConditions);
+        } else {
+            innerJoinStats = crossJoinStats;
         }
+        // if (!join.getOtherJoinConjuncts().isEmpty()) {
+        //     FilterEstimation filterEstimation = new FilterEstimation();
+        //     innerJoinStats = filterEstimation.estimate(
+        //             ExpressionUtils.and(join.getOtherJoinConjuncts()), innerJoinStats);
+        // }
         innerJoinStats.setWidth(leftStats.getWidth() + rightStats.getWidth());
         innerJoinStats.setPenalty(0);
         double rowCount;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java
index 8f10d6f4ea..3143e196f5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java
@@ -22,10 +22,12 @@ import org.apache.doris.catalog.TableIf;
 import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.memo.Group;
 import org.apache.doris.nereids.memo.GroupExpression;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
 import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
@@ -94,6 +96,7 @@ import org.apache.doris.statistics.Statistics;
 import org.apache.doris.statistics.StatisticsBuilder;
 
 import com.google.common.collect.Maps;
+import org.apache.commons.collections.CollectionUtils;
 
 import java.util.AbstractMap.SimpleEntry;
 import java.util.HashMap;
@@ -394,9 +397,24 @@ public class StatsCalculator extends 
DefaultPlanVisitor<Statistics, Void> {
     }
 
     private Statistics computeFilter(Filter filter) {
-        FilterEstimation filterEstimation = new FilterEstimation();
         Statistics stats = groupExpression.childStatistics(0);
-        return filterEstimation.estimate(filter.getPredicate(), stats);
+        Plan plan = tryToFindChild(groupExpression);
+        if (plan != null) {
+            if (plan instanceof Aggregate) {
+                Aggregate agg = ((Aggregate<?>) plan);
+                List<NamedExpression> expressions = agg.getOutputExpressions();
+                Set<Slot> slots = expressions
+                        .stream()
+                        .filter(Alias.class::isInstance)
+                        .filter(s -> ((Alias) s).child().anyMatch(AggregateFunction.class::isInstance))
+                        .map(NamedExpression::toSlot).collect(Collectors.toSet());
+                Expression predicate = filter.getPredicate();
+                if (predicate.anyMatch(s -> slots.contains(s))) {
+                    return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
+                }
+            }
+        }
+        return new FilterEstimation().estimate(filter.getPredicate(), stats);
     }
 
     // TODO: 1. Subtract the pruned partition
@@ -441,13 +459,8 @@ public class StatsCalculator extends 
DefaultPlanVisitor<Statistics, Void> {
         if (!groupByExpressions.isEmpty()) {
             Map<Expression, ColumnStatistic> childSlotToColumnStats = 
childStats.columnStatistics();
             double inputRowCount = childStats.getRowCount();
-            if (inputRowCount == 0) {
-                //on empty relation, Agg output 1 tuple
-                resultSetCount = 1;
-            } else {
+            if (inputRowCount != 0) {
                 List<ColumnStatistic> groupByKeyStats = 
groupByExpressions.stream()
-                        .flatMap(expr -> expr.getInputSlots().stream())
-                        .map(Slot::getExprId)
                         .filter(childSlotToColumnStats::containsKey)
                         .map(childSlotToColumnStats::get)
                         .filter(s -> !s.isUnKnown)
@@ -692,4 +705,16 @@ public class StatsCalculator extends 
DefaultPlanVisitor<Statistics, Void> {
                 .setAvgSizeByte(newAverageRowSize);
         return columnStatisticBuilder.build();
     }
+
+    private Plan tryToFindChild(GroupExpression groupExpression) {
+        List<GroupExpression> groupExprs = groupExpression.child(0).getLogicalExpressions();
+        if (CollectionUtils.isEmpty(groupExprs)) {
+            groupExprs = groupExpression.child(0).getPhysicalExpressions();
+            if (CollectionUtils.isEmpty(groupExprs)) {
+                return null;
+            }
+        }
+        return groupExprs.get(0).getPlan();
+    }
+
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java
index 6966fc97ea..54d309ac14 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsErrorEstimator.java
@@ -52,7 +52,7 @@ public class StatsErrorEstimator {
     }
 
     /**
-     * Map plan id to stats.
+     * Invoked by PhysicalPlanTranslator, put the translated plan node and corresponding physical plan to estimator.
      */
     public void updateLegacyPlanIdToPhysicalPlan(PlanNode planNode, 
AbstractPlan physicalPlan) {
         Statistics statistics = physicalPlan.getStats();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java
index df34c2f9d6..0345c7930e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/StatisticConstants.java
@@ -62,4 +62,6 @@ public class StatisticConstants {
 
     public static final int LOAD_TASK_LIMITS = 10;
 
+    public static final double DEFAULT_INNER_JOIN_FACTOR = 0.1;
+
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java
index b9cf6040e8..2aa0d1ad33 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/Statistics.java
@@ -83,14 +83,25 @@ public class Statistics {
         return rowCount;
     }
 
+    /*
+     * Return a stats with new rowCount and fix each column stats.
+     */
     public Statistics withRowCount(double rowCount) {
+        if (Double.isNaN(rowCount)) {
+            return this;
+        }
         Statistics statistics = new Statistics(rowCount, new 
HashMap<>(expressionToColumnStats), width, penalty);
         statistics.fix(rowCount, StatsMathUtil.nonZeroDivisor(this.rowCount));
         return statistics;
     }
 
+    public Statistics updateRowCountOnly(double rowCount) {
+        return new Statistics(rowCount, expressionToColumnStats);
+    }
+
     public void fix(double newRowCount, double originRowCount) {
         double sel = newRowCount / originRowCount;
+
         for (Entry<Expression, ColumnStatistic> entry : 
expressionToColumnStats.entrySet()) {
             ColumnStatistic columnStatistic = entry.getValue();
             ColumnStatisticBuilder columnStatisticBuilder = new 
ColumnStatisticBuilder(columnStatistic);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
index 691cf53720..0ab0c7c2ff 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
@@ -115,9 +115,7 @@ class FilterEstimationTest {
         Statistics stat = new Statistics(1000, slotToColumnStat);
         FilterEstimation filterEstimation = new FilterEstimation();
         Statistics expected = filterEstimation.estimate(in, stat);
-        Assertions.assertEquals(
-                FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT * 
stat.getRowCount(),
-                expected.getRowCount());
+        Assertions.assertTrue(Precision.equals(333.33, expected.getRowCount(), 0.01));
     }
 
     @Test
@@ -134,9 +132,7 @@ class FilterEstimationTest {
         Statistics stat = new Statistics(1000, slotToColumnStat);
         FilterEstimation filterEstimation = new FilterEstimation();
         Statistics expected = filterEstimation.estimate(notIn, stat);
-        Assertions.assertEquals(
-                FilterEstimation.DEFAULT_INEQUALITY_COEFFICIENT * 
stat.getRowCount(),
-                expected.getRowCount());
+        Assertions.assertTrue(Precision.equals(333.33, expected.getRowCount(), 0.01));
     }
 
     /**
diff --git a/tools/qerror.py b/tools/qerror.py
index 428d9fdf90..70920b60a4 100644
--- a/tools/qerror.py
+++ b/tools/qerror.py
@@ -21,6 +21,7 @@ import subprocess
 
 import requests
 import json
+import time
 
 mycli_cmd = "mysql -h127.0.0.1 -P9030 -uroot -Dtpch1G"
 
@@ -37,6 +38,8 @@ sql_file_prefix_for_trace = """
     SET session_context='trace_id:{}';
 """
 
+q_err_list = []
+
 
 def extract_number(string):
     return int(''.join([c for c in string if c.isdigit()]))
@@ -73,6 +76,7 @@ def execute_sql(sql_file: str):
 
 
 def get_q_error(trace_id):
+    time.sleep(1)
     # 'YWRtaW46' is the base64 encoded result for 'admin:'
     headers = {'Authorization': 'BASIC YWRtaW46'}
     resp_wrapper = requests.get(trace_url.format(trace_id), headers=headers)
@@ -81,9 +85,13 @@ def get_q_error(trace_id):
     resp_wrapper = requests.get(qerror_url.format(query_id), headers=headers)
     resp_text = resp_wrapper.text
     write_result(str(trace_id), resp_text)
+    print(trace_id)
+    print(resp_text)
+    qerr = json.loads(resp_text)["qError"]
+    q_err_list.append(float(qerr))
 
 
-def iterates_sqls(path: str) -> list:
+def iterates_sqls(path: str, if_write_results: bool) -> list:
     cost_times = []
     files = os.listdir(path)
     files.sort(key=extract_number)
@@ -93,14 +101,22 @@ def iterates_sqls(path: str) -> list:
             traced_sql_file = filepath + ".traced"
             content = read_lines(filepath)
             sql_num = extract_number(filename)
-            write_results(traced_sql_file, 
str(sql_file_prefix_for_trace.format(sql_num)), content)
-            execute_sql(traced_sql_file)
-            get_q_error(sql_num)
-            os.remove(traced_sql_file)
+            print("sql num" + str(sql_num))
+            if if_write_results:
+                write_results(traced_sql_file, str(sql_file_prefix_for_trace.format(sql_num)), content)
+                execute_sql(traced_sql_file)
+                get_q_error(sql_num)
+                os.remove(traced_sql_file)
+            else:
+                execute_sql(filepath)
     return cost_times
 
 
 if __name__ == '__main__':
     execute_command("echo 'set global enable_nereids_planner=true' | mysql 
-h127.0.0.1 -P9030")
     execute_command("echo 'set global 
enable_fallback_to_original_planner=false' | mysql -h127.0.0.1 -P9030")
-    iterates_sqls(original_sql_dir)
+    print("Preparing")
+    iterates_sqls(original_sql_dir, False)
+    print("Started...")
+    iterates_sqls(original_sql_dir, True)
+    write_results(qerr_saved_file_path, "AVG\n", [sum(q_err_list) / len(qerror_url)])


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to