This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new ae252d1cfa [opt](Nereids) simplify decimalv3 comparison predicate
(#18975)
ae252d1cfa is described below
commit ae252d1cfa334f20599219832579feae6d541a89
Author: morrySnow <[email protected]>
AuthorDate: Wed Apr 26 23:57:09 2023 +0800
[opt](Nereids) simplify decimalv3 comparison predicate (#18975)
1. fix constant folding failing on decimalv3 types
2. support reducing decimalv3 literal precision in comparison predicates
3. support the fe config enable_decimal_conversion
---
.../doris/nereids/parser/LogicalPlanBuilder.java | 10 ++-
.../rules/expression/ExpressionOptimization.java | 2 +
.../expression/rules/FoldConstantRuleOnFE.java | 1 +
.../rules/SimplifyDecimalV3Comparison.java | 77 ++++++++++++++++++++++
.../trees/expressions/ExpressionEvaluator.java | 9 ++-
.../expressions/literal/DecimalV3Literal.java | 4 +-
.../rules/expression/ExpressionRewriteTest.java | 21 ++++++
7 files changed, 119 insertions(+), 5 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
index f22ff8d63c..8f265cff78 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.parser;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.analysis.SetType;
import org.apache.doris.analysis.UserIdentity;
+import org.apache.doris.common.Config;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.DorisParser;
@@ -185,6 +186,7 @@ import
org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Interval;
@@ -1742,8 +1744,12 @@ public class LogicalPlanBuilder extends
DorisParserBaseVisitor<Object> {
}
@Override
- public DecimalLiteral visitDecimalLiteral(DecimalLiteralContext ctx) {
- return new DecimalLiteral(new BigDecimal(ctx.getText()));
+    public Literal visitDecimalLiteral(DecimalLiteralContext ctx) {
+        // parsed as DecimalV3Literal when Config.enable_decimal_conversion is on, DecimalLiteral otherwise
+        BigDecimal value = new BigDecimal(ctx.getText());
+        return Config.enable_decimal_conversion ? new DecimalV3Literal(value) : new DecimalLiteral(value);
}
private String parseTVFPropertyItem(TvfPropertyItemContext item) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
index 47840f226e..50c0af4402 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression;
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
+import
org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
import com.google.common.collect.ImmutableList;
@@ -34,6 +35,7 @@ public class ExpressionOptimization extends ExpressionRewrite
{
ExtractCommonFactorRule.INSTANCE,
DistinctPredicatesRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE,
+ SimplifyDecimalV3Comparison.INSTANCE,
SimplifyRange.INSTANCE
);
private static final ExpressionRuleExecutor EXECUTOR = new
ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
index c83d3029ce..d729273a64 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java
@@ -79,6 +79,7 @@ import java.util.Objects;
* evaluate an expression on fe.
*/
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
+
public static final FoldConstantRuleOnFE INSTANCE = new
FoldConstantRuleOnFE();
@Override
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
new file mode 100644
index 0000000000..93021f0b58
--- /dev/null
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
@@ -0,0 +1,119 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
+import org.apache.doris.nereids.types.DecimalV3Type;
+
+import com.google.common.base.Preconditions;
+
+import java.math.BigDecimal;
+
+/**
+ * Removes a redundant widening cast from a decimalv3 comparison predicate.
+ *
+ * <p>For example, take a column {@code col1} of type decimalv3(15, 2) and the predicate
+ * {@code col1 > 0.5 + 0.1}. The right-hand side may fold to a literal of a wider decimal type,
+ * such as decimalv3(27, 9), and {@code col1} is then cast to that type to match:
+ * {@code cast(col1 as decimalv3(27, 9)) > 0.6}. Since 0.6 is exactly representable as
+ * decimalv3(15, 2), this rule rewrites the predicate to {@code col1 > 0.60}.
+ *
+ * <p>The predicate is left unchanged whenever dropping the cast could change its result: when the
+ * cast narrows the column's type, or when the literal needs more fractional or integer digits than
+ * the column's type provides (e.g. {@code 0.125} or {@code 1e20} against decimalv3(15, 2)).
+ */
+public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
+
+    public static final SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();
+
+    @Override
+    public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
+        Expression left = rewrite(cp.left(), context);
+        Expression right = rewrite(cp.right(), context);
+
+        if (left.getDataType() instanceof DecimalV3Type
+                && left instanceof Cast
+                && ((Cast) left).child().getDataType() instanceof DecimalV3Type
+                && right instanceof DecimalV3Literal) {
+            Expression simplified = doProcess(cp, (Cast) left, (DecimalV3Literal) right);
+            if (simplified != null) {
+                return simplified;
+            }
+        }
+
+        // keep the rewritten children even when the cast cannot be removed
+        if (left != cp.left() || right != cp.right()) {
+            return cp.withChildren(left, right);
+        } else {
+            return cp;
+        }
+    }
+
+    /**
+     * Tries to replace {@code cast(child) op literal} with {@code child op literal'}, where
+     * {@code literal'} is the same value typed as the child's decimalv3 type.
+     *
+     * @param cp the comparison being rewritten, used as the template for the result
+     * @param left the cast on the left-hand side; its child and its target type are decimalv3
+     * @param right the decimalv3 literal on the right-hand side
+     * @return the simplified comparison, or {@code null} if removing the cast could change the
+     *         result
+     */
+    private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
+        Expression castChild = left.child();
+        Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
+        DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
+        DecimalV3Type castType = (DecimalV3Type) left.getDataType();
+
+        // a narrowing cast rounds or truncates the child's values, so comparing the child
+        // directly would not be equivalent
+        if (!fits(leftType.getScale(), leftType.getPrecision() - leftType.getScale(), castType)) {
+            return null;
+        }
+
+        BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros();
+        // stripTrailingZeros may yield a negative scale (e.g. 100 -> 1E+2): no fractional digits
+        int scale = Math.max(0, trailingZerosValue.scale());
+        int integerDigits = trailingZerosValue.signum() == 0
+                ? 0
+                : Math.max(0, trailingZerosValue.precision() - trailingZerosValue.scale());
+        // the literal must be exactly representable in the child's type, otherwise keep the cast
+        if (!fits(scale, integerDigits, leftType)) {
+            return null;
+        }
+
+        // setScale(int) is exact here because scale <= leftType.getScale()
+        DecimalV3Literal newRight = new DecimalV3Literal(
+                DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()),
+                trailingZerosValue.setScale(leftType.getScale()));
+        return cp.withChildren(castChild, newRight);
+    }
+
+    /**
+     * Returns whether a value with {@code scale} fractional digits and {@code integerDigits}
+     * integer digits is representable in {@code type} without loss.
+     */
+    private static boolean fits(int scale, int integerDigits, DecimalV3Type type) {
+        return scale <= type.getScale() && integerDigits <= type.getPrecision() - type.getScale();
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
index 0eb013614a..f68309f26d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java
@@ -24,6 +24,7 @@ import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV3Type;
import com.google.common.collect.ImmutableMultimap;
@@ -104,7 +105,7 @@ public enum ExpressionEvaluator {
}
boolean match = true;
for (int i = 0; i < candidateTypes.length; i++) {
- if (!candidateTypes[i].equals(expectedTypes[i])) {
+ if
(!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType())))
{
match = false;
break;
}
@@ -142,7 +143,11 @@ public enum ExpressionEvaluator {
DataType returnType =
DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
- argTypes.add(DataType.convertFromString(type));
+ if (type.equalsIgnoreCase("DECIMALV3")) {
+ argTypes.add(DecimalV3Type.WILDCARD);
+ } else {
+ argTypes.add(DataType.convertFromString(type));
+ }
}
FunctionSignature signature = new FunctionSignature(name,
argTypes.toArray(new DataType[argTypes.size()]),
returnType);
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
index 18c13c4ac5..c6c28fa3e9 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
@@ -39,7 +39,9 @@ public class DecimalV3Literal extends Literal {
public DecimalV3Literal(DecimalV3Type dataType, BigDecimal value) {
super(DecimalV3Type.createDecimalV3Type(dataType.getPrecision(),
dataType.getScale()));
- this.value =
Objects.requireNonNull(value.setScale(dataType.getScale(), RoundingMode.DOWN));
+        Objects.requireNonNull(value, "value not be null");
+        // NOTE(review): a value with a negative scale (e.g. 1E+2) is stored unchanged, so its scale
+        // can differ from dataType's scale -- confirm this is intended.
+        this.value = value.scale() < 0 ? value : value.setScale(dataType.getScale(), RoundingMode.DOWN);
}
@Override
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
index 0dbaf6d30c..65d9e8c6ac 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
@@ -25,6 +25,7 @@ import
org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
import
org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
import
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
+import
org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.EqualTo;
@@ -38,6 +39,7 @@ import
org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
@@ -46,6 +48,7 @@ import
org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
+import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
@@ -283,4 +286,22 @@ public class ExpressionRewriteTest extends
ExpressionRewriteTestHelper {
new EqualTo(dv2, dv2));
}
+
+    @Test
+    public void testSimplifyDecimalV3Comparison() {
+        executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
+
+        // cast(<decimalv3(7, 2) literal> as decimalv3(27, 9)) = 0.01 loses the cast,
+        // and the literal is re-typed to decimalv3(7, 2)
+        Expression column = new DecimalV3Literal(new BigDecimal("12345.67"));
+        Expression widened = new Cast(column, DecimalV3Type.createDecimalV3Type(27, 9));
+        Expression wideLiteral = new DecimalV3Literal(
+                DecimalV3Type.createDecimalV3Type(27, 9), new BigDecimal("0.01"));
+        Expression narrowLiteral = new DecimalV3Literal(
+                DecimalV3Type.createDecimalV3Type(7, 2), new BigDecimal("0.01"));
+        assertRewrite(new EqualTo(widened, wideLiteral), new EqualTo(column, narrowLiteral));
+
+        // without a cast on the left there is nothing to simplify
+        Expression noCast = new EqualTo(
+                new DecimalV3Literal(new BigDecimal("12345.67")),
+                new DecimalV3Literal(new BigDecimal("76543.21")));
+        assertRewrite(noCast, noCast);
+    }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]