This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new ae252d1cfa [opt](Nereids) simplify decimalv3 comparison predicate (#18975)
ae252d1cfa is described below

commit ae252d1cfa334f20599219832579feae6d541a89
Author: morrySnow <[email protected]>
AuthorDate: Wed Apr 26 23:57:09 2023 +0800

    [opt](Nereids) simplify decimalv3 comparison predicate (#18975)
    
    1. fix constant folding failing on decimalv3 types
    2. support reducing decimalv3 literal precision in comparison predicates
    3. support the FE config enable_decimal_conversion
---
 .../doris/nereids/parser/LogicalPlanBuilder.java   | 10 ++-
 .../rules/expression/ExpressionOptimization.java   |  2 +
 .../expression/rules/FoldConstantRuleOnFE.java     |  1 +
 .../rules/SimplifyDecimalV3Comparison.java         | 77 ++++++++++++++++++++++
 .../trees/expressions/ExpressionEvaluator.java     |  9 ++-
 .../expressions/literal/DecimalV3Literal.java      |  4 +-
 .../rules/expression/ExpressionRewriteTest.java    | 21 ++++++
 7 files changed, 119 insertions(+), 5 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
index f22ff8d63c..8f265cff78 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.parser;
 import org.apache.doris.analysis.ArithmeticExpr.Operator;
 import org.apache.doris.analysis.SetType;
 import org.apache.doris.analysis.UserIdentity;
+import org.apache.doris.common.Config;
 import org.apache.doris.common.DdlException;
 import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.DorisParser;
@@ -185,6 +186,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.Interval;
@@ -1742,8 +1744,12 @@ public class LogicalPlanBuilder extends 
DorisParserBaseVisitor<Object> {
     }
 
     @Override
-    public DecimalLiteral visitDecimalLiteral(DecimalLiteralContext ctx) {
-        return new DecimalLiteral(new BigDecimal(ctx.getText()));
+    public Literal visitDecimalLiteral(DecimalLiteralContext ctx) {
+        BigDecimal parsed = new BigDecimal(ctx.getText());
+        // Config.enable_decimal_conversion chooses DecimalV3Literal over DecimalLiteral
+        return Config.enable_decimal_conversion
+                ? new DecimalV3Literal(parsed)
+                : new DecimalLiteral(parsed);
     }
 
     private String parseTVFPropertyItem(TvfPropertyItemContext item) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
index 47840f226e..50c0af4402 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression;
 import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
 import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
 import 
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
+import 
org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
 
 import com.google.common.collect.ImmutableList;
@@ -34,6 +35,7 @@ public class ExpressionOptimization extends ExpressionRewrite 
{
             ExtractCommonFactorRule.INSTANCE,
             DistinctPredicatesRule.INSTANCE,
             SimplifyComparisonPredicate.INSTANCE,
+            SimplifyDecimalV3Comparison.INSTANCE,
             SimplifyRange.INSTANCE
     );
     private static final ExpressionRuleExecutor EXECUTOR = new 
ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
index c83d3029ce..d729273a64 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
@@ -79,6 +79,7 @@ import java.util.Objects;
  * evaluate an expression on fe.
  */
 public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
+
     public static final FoldConstantRuleOnFE INSTANCE = new 
FoldConstantRuleOnFE();
 
     @Override
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
new file mode 100644
index 0000000000..93021f0b58
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
@@ -0,0 +1,106 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
+import org.apache.doris.nereids.types.DecimalV3Type;
+
+import com.google.common.base.Preconditions;
+
+import java.math.BigDecimal;
+
+/**
+ * Simplifies a comparison between a widened decimalv3 expression and a decimalv3 literal.
+ *
+ * <p>For example, if col1 has type decimalv3(15, 2) and is compared with a decimalv3(27, 9) value,
+ * col1 appears as cast(col1 as decimalv3(27, 9)). This rule rewrites
+ * cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6, retyping the literal as decimalv3(15, 2).
+ *
+ * <p>The rewrite is applied only when it cannot change the result: the cast must not drop
+ * fractional or integer digits of col1, and the literal must be representable in col1's type
+ * without rounding or overflow. Otherwise the predicate is kept unchanged.
+ */
+public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
+
+    public static final SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();
+
+    @Override
+    public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
+        Expression left = rewrite(cp.left(), context);
+        Expression right = rewrite(cp.right(), context);
+
+        if (left instanceof Cast
+                && right instanceof DecimalV3Literal
+                && canRemoveCast((Cast) left, (DecimalV3Literal) right)) {
+            return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
+        }
+
+        if (left != cp.left() || right != cp.right()) {
+            return cp.withChildren(left, right);
+        } else {
+            return cp;
+        }
+    }
+
+    /**
+     * Returns true when comparing the cast child directly against the literal, retyped to the
+     * child's type, is guaranteed to give the same result as comparing through the cast.
+     */
+    private static boolean canRemoveCast(Cast cast, DecimalV3Literal literal) {
+        if (!(cast.getDataType() instanceof DecimalV3Type)
+                || !(cast.child().getDataType() instanceof DecimalV3Type)) {
+            return false;
+        }
+        DecimalV3Type targetType = (DecimalV3Type) cast.getDataType();
+        DecimalV3Type childType = (DecimalV3Type) cast.child().getDataType();
+        // a narrowing cast rounds or overflows the child's values, so removing it would change results
+        if (targetType.getScale() < childType.getScale()
+                || integerDigits(targetType) < integerDigits(childType)) {
+            return false;
+        }
+        BigDecimal value = literal.getValue().stripTrailingZeros();
+        int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(value);
+        int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(value);
+        // the literal must fit the child's type, otherwise retyping it would round or overflow it
+        return scale <= childType.getScale() && precision - scale <= integerDigits(childType);
+    }
+
+    /** Number of digits before the decimal point that a value of {@code type} can hold. */
+    private static int integerDigits(DecimalV3Type type) {
+        return type.getPrecision() - type.getScale();
+    }
+
+    /**
+     * Rewrites cast(child) op literal to child op literal', where literal' holds the literal's value
+     * typed as the child's decimalv3 type. Callers must check {@link #canRemoveCast} first.
+     */
+    private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
+        Expression castChild = left.child();
+        Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
+        DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
+        DecimalV3Literal newRight = new DecimalV3Literal(
+                DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()),
+                right.getValue().stripTrailingZeros());
+        return cp.withChildren(castChild, newRight);
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
index 0eb013614a..f68309f26d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
@@ -24,6 +24,7 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV3Type;
 
 import com.google.common.collect.ImmutableMultimap;
 
@@ -104,7 +105,7 @@ public enum ExpressionEvaluator {
             }
             boolean match = true;
             for (int i = 0; i < candidateTypes.length; i++) {
-                if (!candidateTypes[i].equals(expectedTypes[i])) {
+                if 
(!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType())))
 {
                     match = false;
                     break;
                 }
@@ -142,7 +143,11 @@ public enum ExpressionEvaluator {
             DataType returnType = 
DataType.convertFromString(annotation.returnType());
             List<DataType> argTypes = new ArrayList<>();
             for (String type : annotation.argTypes()) {
-                argTypes.add(DataType.convertFromString(type));
+                if (type.equalsIgnoreCase("DECIMALV3")) {
+                    argTypes.add(DecimalV3Type.WILDCARD);
+                } else {
+                    argTypes.add(DataType.convertFromString(type));
+                }
             }
             FunctionSignature signature = new FunctionSignature(name,
                     argTypes.toArray(new DataType[argTypes.size()]), 
returnType);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
index 18c13c4ac5..c6c28fa3e9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
@@ -39,7 +39,9 @@ public class DecimalV3Literal extends Literal {
 
     public DecimalV3Literal(DecimalV3Type dataType, BigDecimal value) {
         super(DecimalV3Type.createDecimalV3Type(dataType.getPrecision(), dataType.getScale()));
-        this.value = Objects.requireNonNull(value.setScale(dataType.getScale(), RoundingMode.DOWN));
+        BigDecimal checked = Objects.requireNonNull(value, "value not be null");
+        // a negative scale is kept as is; otherwise rescale to the type's scale, dropping excess digits
+        this.value = checked.scale() < 0 ? checked : checked.setScale(dataType.getScale(), RoundingMode.DOWN);
     }
 
     @Override
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
index 0dbaf6d30c..65d9e8c6ac 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
@@ -25,6 +25,7 @@ import 
org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
 import 
org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
 import 
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
+import 
org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
@@ -38,6 +39,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
@@ -46,6 +48,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
 import org.apache.doris.nereids.types.DateTimeType;
 import org.apache.doris.nereids.types.DateTimeV2Type;
 import org.apache.doris.nereids.types.DecimalV2Type;
+import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.StringType;
 import org.apache.doris.nereids.types.VarcharType;
 
@@ -283,4 +286,22 @@ public class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
                 new EqualTo(dv2, dv2));
 
     }
+
+    @Test
+    public void testSimplifyDecimalV3Comparison() {
+        executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
+
+        // cast(decimalv3(7, 2) as decimalv3(27, 9)) = 0.01 drops the cast and retypes 0.01 as decimalv3(7, 2)
+        DecimalV3Type wideType = DecimalV3Type.createDecimalV3Type(27, 9);
+        Expression column = new DecimalV3Literal(new BigDecimal("12345.67"));
+        Expression wideLiteral = new DecimalV3Literal(wideType, new BigDecimal("0.01"));
+        Expression narrowLiteral = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(7, 2), new BigDecimal("0.01"));
+        assertRewrite(new EqualTo(new Cast(column, wideType), wideLiteral),
+                new EqualTo(column, narrowLiteral));
+
+        // without a cast on the left the predicate is returned untouched
+        Expression noCast = new EqualTo(new DecimalV3Literal(new BigDecimal("12345.67")),
+                new DecimalV3Literal(new BigDecimal("76543.21")));
+        assertRewrite(noCast, noCast);
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to