This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 5b83115e87b [pick](nereids) fix bug in case-when/if stats estimation 
(#30265) (#30691)
5b83115e87b is described below

commit 5b83115e87b3a1d4c802f16e64f3e9476f30ae2c
Author: minghong <[email protected]>
AuthorDate: Fri Feb 2 17:27:36 2024 +0800

    [pick](nereids) fix bug in case-when/if stats estimation (#30265) (#30691)
    
    pick #30265
---
 .../doris/nereids/stats/ExpressionEstimation.java  | 141 +++++++++++++++------
 .../nereids/stats/ExpressionEstimationTest.java    | 134 ++++++++++++++++++++
 .../doris/nereids/stats/FilterEstimationTest.java  |   2 +
 3 files changed, 236 insertions(+), 41 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
index f231126417e..8972a26b032 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
@@ -18,6 +18,11 @@
 package org.apache.doris.nereids.stats;
 
 import org.apache.doris.analysis.ArithmeticExpr.Operator;
+import org.apache.doris.analysis.DecimalLiteral;
+import org.apache.doris.analysis.FloatLiteral;
+import org.apache.doris.analysis.IntLiteral;
+import org.apache.doris.analysis.LargeIntLiteral;
+import org.apache.doris.analysis.LiteralExpr;
 import org.apache.doris.analysis.StringLiteral;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Add;
@@ -38,6 +43,7 @@ import 
org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.Subtract;
 import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
 import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@@ -136,21 +142,37 @@ public class ExpressionEstimation extends 
ExpressionVisitor<ColumnStatistic, Sta
     //TODO: case-when need to re-implemented
     @Override
     public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics 
context) {
+        double ndv = caseWhen.getWhenClauses().size();
+        if (caseWhen.getDefaultValue().isPresent()) {
+            ndv += 1;
+        }
+        for (WhenClause clause : caseWhen.getWhenClauses()) {
+            ColumnStatistic colStats = 
ExpressionEstimation.estimate(clause.getResult(), context);
+            ndv = Math.max(ndv, colStats.ndv);
+        }
+        if (caseWhen.getDefaultValue().isPresent()) {
+            ColumnStatistic colStats = 
ExpressionEstimation.estimate(caseWhen.getDefaultValue().get(), context);
+            ndv = Math.max(ndv, colStats.ndv);
+        }
         return new ColumnStatisticBuilder()
-                .setNdv(caseWhen.getWhenClauses().size() + 1)
-                .setMinValue(0)
-                .setMaxValue(Double.MAX_VALUE)
+                .setNdv(ndv)
+                .setMinValue(Double.NEGATIVE_INFINITY)
+                .setMaxValue(Double.POSITIVE_INFINITY)
                 .setAvgSizeByte(8)
                 .setNumNulls(0)
                 .build();
     }
 
     @Override
-    public ColumnStatistic visitIf(If function, Statistics context) {
-        // TODO: copy from visitCaseWhen, polish them.
+    public ColumnStatistic visitIf(If ifClause, Statistics context) {
+        double ndv = 2;
+        ColumnStatistic colStatsThen = 
ExpressionEstimation.estimate(ifClause.child(1), context);
+        ndv = Math.max(ndv, colStatsThen.ndv);
+        ColumnStatistic colStatsElse = 
ExpressionEstimation.estimate(ifClause.child(2), context);
+        ndv = Math.max(ndv, colStatsElse.ndv);
         return new ColumnStatisticBuilder()
-                .setNdv(2)
-                .setMinValue(0)
+                .setNdv(ndv)
+                .setMinValue(Double.NEGATIVE_INFINITY)
                 .setMaxValue(Double.POSITIVE_INFINITY)
                 .setAvgSizeByte(8)
                 .setNumNulls(0)
@@ -169,35 +191,55 @@ public class ExpressionEstimation extends 
ExpressionVisitor<ColumnStatistic, Sta
     }
 
     private ColumnStatistic castMinMax(ColumnStatistic colStats, DataType 
targetType) {
-        if (colStats.minExpr instanceof StringLiteral || colStats.maxExpr 
instanceof StringLiteral) {
-            if (targetType.isDateLikeType()) {
-                ColumnStatisticBuilder builder = new 
ColumnStatisticBuilder(colStats);
-                if (colStats.minExpr != null) {
-                    try {
-                        String strMin = colStats.minExpr.getStringValue();
-                        DateLiteral dateMinLiteral = new DateLiteral(strMin);
-                        long min = dateMinLiteral.getValue();
-                        builder.setMinValue(min);
-                        builder.setMinExpr(dateMinLiteral.toLegacyLiteral());
-                    } catch (AnalysisException e) {
-                        // ignore exception. do not convert min
-                    }
+        // cast str to date/datetime
+        if (colStats.minExpr instanceof StringLiteral
+                && colStats.maxExpr instanceof StringLiteral
+                && targetType.isDateLikeType()) {
+            boolean convertSuccess = true;
+            ColumnStatisticBuilder builder = new 
ColumnStatisticBuilder(colStats);
+            if (colStats.minExpr != null) {
+                try {
+                    String strMin = colStats.minExpr.getStringValue();
+                    DateLiteral dateMinLiteral = new DateLiteral(strMin);
+                    long min = dateMinLiteral.getValue();
+                    builder.setMinValue(min);
+                    builder.setMinExpr(dateMinLiteral.toLegacyLiteral());
+                } catch (AnalysisException e) {
+                    convertSuccess = false;
                 }
-                if (colStats.maxExpr != null) {
-                    try {
-                        String strMax = colStats.maxExpr.getStringValue();
-                        DateLiteral dateMaxLiteral = new DateLiteral(strMax);
-                        long max = dateMaxLiteral.getValue();
-                        builder.setMaxValue(max);
-                        builder.setMaxExpr(dateMaxLiteral.toLegacyLiteral());
-                    } catch (AnalysisException e) {
-                        // ignore exception. do not convert max
-                    }
+            }
+            if (convertSuccess && colStats.maxExpr != null) {
+                try {
+                    String strMax = colStats.maxExpr.getStringValue();
+                    DateLiteral dateMaxLiteral = new DateLiteral(strMax);
+                    long max = dateMaxLiteral.getValue();
+                    builder.setMaxValue(max);
+                    builder.setMaxExpr(dateMaxLiteral.toLegacyLiteral());
+                } catch (AnalysisException e) {
+                    convertSuccess = false;
                 }
+            }
+            if (convertSuccess) {
                 return builder.build();
             }
         }
-        return colStats;
+        // cast numeric to numeric
+        if (isNumericLiteralExpr(colStats.minExpr) && 
isNumericLiteralExpr(colStats.maxExpr)) {
+            if (targetType.isNumericType()) {
+                return colStats;
+            }
+        }
+
+        // cast between other data types: set min/max to infinity
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder(colStats);
+        builder.setMinExpr(null).setMinValue(Double.NEGATIVE_INFINITY)
+                .setMaxExpr(null).setMaxValue(Double.POSITIVE_INFINITY);
+        return builder.build();
+    }
+
+    private boolean isNumericLiteralExpr(LiteralExpr literal) {
+        return literal instanceof DecimalLiteral || literal instanceof 
FloatLiteral
+                || literal instanceof IntLiteral || literal instanceof 
LargeIntLiteral;
     }
 
     @Override
@@ -561,13 +603,22 @@ public class ExpressionEstimation extends 
ExpressionVisitor<ColumnStatistic, Sta
         if (childColumnStats.minOrMaxIsInf()) {
             return columnStatisticBuilder.build();
         }
-        double minValue = getDatetimeFromLong((long) 
childColumnStats.minValue).toLocalDate()
-                .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
-        double maxValue = getDatetimeFromLong((long) 
childColumnStats.maxValue).toLocalDate()
-                .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
+        double minValue;
+        double maxValue;
+        try {
+            // min/max values are finite here, but they may be too large to convert 
to date
+            minValue = getDatetimeFromLong((long) 
childColumnStats.minValue).toLocalDate()
+                    .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
+            maxValue = getDatetimeFromLong((long) 
childColumnStats.maxValue).toLocalDate()
+                    .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
+        } catch (Exception e) {
+            // ignore DateTimeException
+            minValue = Double.NEGATIVE_INFINITY;
+            maxValue = Double.POSITIVE_INFINITY;
+        }
         return columnStatisticBuilder.setMaxValue(maxValue)
-                .setMinValue(minValue)
-                .build();
+                .setMinValue(minValue).build();
+
     }
 
     private LocalDateTime getDatetimeFromLong(long dateTime) {
@@ -583,10 +634,18 @@ public class ExpressionEstimation extends 
ExpressionVisitor<ColumnStatistic, Sta
         if (childColumnStats.minOrMaxIsInf()) {
             return columnStatisticBuilder.build();
         }
-        double minValue = getDatetimeFromLong((long) 
childColumnStats.minValue).toLocalDate().toEpochDay()
-                + (double) DAYS_FROM_0_TO_1970;
-        double maxValue = getDatetimeFromLong((long) 
childColumnStats.maxValue).toLocalDate().toEpochDay()
-                + (double) DAYS_FROM_0_TO_1970;
+        double minValue;
+        double maxValue;
+        try {
+            minValue = getDatetimeFromLong((long) 
childColumnStats.minValue).toLocalDate().toEpochDay()
+                    + (double) DAYS_FROM_0_TO_1970;
+            maxValue = getDatetimeFromLong((long) 
childColumnStats.maxValue).toLocalDate().toEpochDay()
+                    + (double) DAYS_FROM_0_TO_1970;
+        } catch (Exception e) {
+            // ignore DateTimeException
+            minValue = Double.NEGATIVE_INFINITY;
+            maxValue = Double.POSITIVE_INFINITY;
+        }
         return columnStatisticBuilder.setMaxValue(maxValue)
                 .setMinValue(minValue)
                 .build();
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java
index b55b266faff..1748802e4dd 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java
@@ -17,15 +17,25 @@
 
 package org.apache.doris.nereids.stats;
 
+import org.apache.doris.analysis.DateLiteral;
+import org.apache.doris.analysis.StringLiteral;
 import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Divide;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Multiply;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
+import org.apache.doris.nereids.types.DateType;
+import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.StringType;
 import org.apache.doris.statistics.ColumnStatistic;
 import org.apache.doris.statistics.ColumnStatisticBuilder;
 import org.apache.doris.statistics.Statistics;
@@ -34,7 +44,9 @@ import org.apache.commons.math3.util.Precision;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 class ExpressionEstimationTest {
@@ -250,4 +262,126 @@ class ExpressionEstimationTest {
         Assertions.assertTrue(Precision.equals(0.1, estimated.minValue, 
0.001));
         Assertions.assertEquals(2, estimated.maxValue);
     }
+
+    // cast(str to double) = double
+    @Test
+    public void testCastStrToDouble() {
+        SlotReference a = new SlotReference("a", StringType.INSTANCE);
+        Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>();
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(100)
+                .setMinExpr(new StringLiteral("01"))
+                .setMinValue(13333333)
+                .setMaxExpr(new StringLiteral("A9"))
+                .setMaxValue(23333333);
+        slotToColumnStat.put(a, builder.build());
+        Statistics stats = new Statistics(1000, slotToColumnStat);
+        Cast cast = new Cast(a, DoubleType.INSTANCE);
+        ColumnStatistic est = ExpressionEstimation.estimate(cast, stats);
+        Assertions.assertTrue(Double.isInfinite(est.minValue));
+        Assertions.assertTrue(Double.isInfinite(est.maxValue));
+        Assertions.assertNull(est.minExpr);
+        Assertions.assertNull(est.maxExpr);
+    }
+
+    // cast(str to date) = date
+    // both min and max can be converted to date
+    @Test
+    public void testCastStrToDateSuccess() {
+        SlotReference a = new SlotReference("a", StringType.INSTANCE);
+        Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>();
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(100)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021-01-01"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(a, builder.build());
+        Statistics stats = new Statistics(1000, slotToColumnStat);
+        Cast cast = new Cast(a, DateType.INSTANCE);
+        ColumnStatistic est = ExpressionEstimation.estimate(cast, stats);
+        Assertions.assertTrue(est.minExpr instanceof DateLiteral);
+        Assertions.assertTrue(est.maxExpr instanceof DateLiteral);
+        Assertions.assertEquals(est.minValue, 20200101000000.0);
+        Assertions.assertEquals(est.maxValue, 20210101000000.0);
+    }
+
+    // cast(str to date) = date
+    // min or max cannot be converted to date
+    @Test
+    public void testCastStrToDateFail() {
+        SlotReference a = new SlotReference("a", StringType.INSTANCE);
+        Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>();
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(100)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(a, builder.build());
+        Statistics stats = new Statistics(1000, slotToColumnStat);
+        Cast cast = new Cast(a, DateType.INSTANCE);
+        ColumnStatistic est = ExpressionEstimation.estimate(cast, stats);
+        Assertions.assertTrue(Double.isInfinite(est.minValue));
+        Assertions.assertTrue(Double.isInfinite(est.maxValue));
+        Assertions.assertNull(est.minExpr);
+        Assertions.assertNull(est.maxExpr);
+    }
+
+    @Test
+    public void testCaseWhen() {
+        SlotReference a = new SlotReference("a", StringType.INSTANCE);
+        Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>();
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(100)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(a, builder.build());
+        SlotReference b = new SlotReference("b", StringType.INSTANCE);
+        builder = new ColumnStatisticBuilder()
+                .setNdv(10)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(b, builder.build());
+        Statistics stats = new Statistics(1000, slotToColumnStat);
+
+        WhenClause when1 = new WhenClause(BooleanLiteral.TRUE, a);
+        WhenClause when2 = new WhenClause(BooleanLiteral.FALSE, b);
+        List<WhenClause> whens = new ArrayList<>();
+        whens.add(when1);
+        whens.add(when2);
+        CaseWhen caseWhen = new CaseWhen(whens);
+        ColumnStatistic est = ExpressionEstimation.estimate(caseWhen, stats);
+        Assertions.assertEquals(est.ndv, 100);
+    }
+
+    @Test
+    public void testIf() {
+        SlotReference a = new SlotReference("a", StringType.INSTANCE);
+        Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>();
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(100)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(a, builder.build());
+        SlotReference b = new SlotReference("b", StringType.INSTANCE);
+        builder = new ColumnStatisticBuilder()
+                .setNdv(10)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(b, builder.build());
+        Statistics stats = new Statistics(1000, slotToColumnStat);
+
+        If ifClause = new If(BooleanLiteral.TRUE, a, b);
+        ColumnStatistic est = ExpressionEstimation.estimate(ifClause, stats);
+        Assertions.assertEquals(est.ndv, 100);
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
index 177fac64f16..66e64145901 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/FilterEstimationTest.java
@@ -837,6 +837,8 @@ class FilterEstimationTest {
                 .setNumNulls(0)
                 .setMaxValue(100)
                 .setMinValue(0)
+                .setMaxExpr(new IntLiteral(100))
+                .setMinExpr(new IntLiteral(0))
                 .setCount(100);
         DoubleLiteral begin = new DoubleLiteral(40.0);
         DoubleLiteral end = new DoubleLiteral(50.0);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to