This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 38a62d890fd [opt](nereids) simplify arithmethic handle with mix
add/sub/multiply/divide (#45543)
38a62d890fd is described below
commit 38a62d890fde29d15cac04ce70326baeba0813af
Author: yujun <[email protected]>
AuthorDate: Fri Jan 3 14:16:11 2025 +0800
[opt](nereids) simplify arithmethic handle with mix add/sub/multiply/divide
(#45543)
### What problem does this PR solve?
Two optimizations:
1. Handle mixed add / sub / multiply / divide
Previously, SimplifyArithmeticRule handled only add-sub chains or only
multiply-divide chains, but not both within the same expression.
For example, if the expression root was an add, only the add-sub chain was
simplified; nested multiply-divide sub-expressions were left untouched.
For the expression a + 10 + (b * 2 * 3 * (c + 4 + 5)) + 20, constant
folding plus this rule used to produce a + (b * (c + 4 + 5) * 2 * 3) + 30;
after this PR the result is a + (b * (c + 9) * 6) + 30.
2. Handle cast
Previously, SimplifyArithmeticRule did not look through casts.
For example, for the expression cast(a * 2 * 30 as double) / cast(10 as
double), constant folding plus this rule used to produce cast(a * 60 as
double) / 10.0; after this PR the result is cast(a as double) * 6.0.
---
.../expression/rules/SimplifyArithmeticRule.java | 86 +++++++++++++++-------
.../expression/SimplifyArithmeticRuleTest.java | 27 +++++--
2 files changed, 80 insertions(+), 33 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
index 44d6505b003..6eea495e5cf 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyArithmeticRule.java
@@ -22,10 +22,12 @@ import
org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.nereids.util.TypeUtils;
import org.apache.doris.nereids.util.Utils;
@@ -35,6 +37,7 @@ import com.google.common.collect.Lists;
import java.util.List;
import java.util.Optional;
+import java.util.function.Predicate;
/**
* Simplify arithmetic rule.
@@ -91,6 +94,9 @@ public class SimplifyArithmeticRule implements
ExpressionPatternRuleFactory {
}
// 2. move variables to left side and move constants to right sid.
for (Operand operand : flattedExpressions) {
+ if (operand.expression instanceof BinaryArithmetic) {
+ operand.expression = simplify((BinaryArithmetic)
operand.expression);
+ }
if (operand.expression.isConstant()) {
constants.add(operand);
} else {
@@ -129,45 +135,73 @@ public class SimplifyArithmeticRule implements
ExpressionPatternRuleFactory {
}
}
+ // isAddOrSub: true for extract only "+" or "-" sub expressions, false for
extract only "*" or "/" sub expressions
private static List<Operand> flatten(Expression expr, boolean isAddOrSub) {
List<Operand> result = Lists.newArrayList();
- if (isAddOrSub) {
- flattenAddSubtract(true, expr, result);
- } else {
- flattenMultiplyDivide(true, expr, result);
- }
+ doFlatten(true, expr, isAddOrSub, result, Optional.empty());
return result;
}
- private static void flattenAddSubtract(boolean flag, Expression expr,
List<Operand> result) {
- if (TypeUtils.isAddOrSubtract(expr)) {
- BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
- flattenAddSubtract(flag, arithmetic.left(), result);
- if (TypeUtils.isSubtract(expr) && !flag) {
- flattenAddSubtract(true, arithmetic.right(), result);
- } else if (TypeUtils.isAdd(expr) && !flag) {
- flattenAddSubtract(false, arithmetic.right(), result);
+ // flag: true for '+' or '*', false for '-' or '/'
+ // isAddOrSub: true for extract only "+" or "-" sub expressions, false for
extract only "*" or "/" sub expressions
+ private static void doFlatten(boolean flag, Expression expr, boolean
isAddOrSub, List<Operand> result,
+ Optional<DataType> castType) {
+ // cast (a * 10 as double) * (cast 20 as double)
+ // => cast(a as double) * (cast 10 as double) * (cast 20 as double)
+ BinaryArithmetic arithmetic = null;
+ Predicate<Expression> isPositiveArithmetic = isAddOrSub
+ ? TypeUtils::isAdd : TypeUtils::isMultiply;
+ Predicate<Expression> isNegativeArithmetic = isAddOrSub
+ ? TypeUtils::isSubtract : TypeUtils::isDivide;
+ Predicate<Expression> isPosNegArithmetic =
isPositiveArithmetic.or(isNegativeArithmetic);
+ if (isPosNegArithmetic.test(expr)) {
+ arithmetic = (BinaryArithmetic) expr;
+ } else if (expr instanceof Cast && hasConstantOperand(expr,
isAddOrSub)) {
+ Cast cast = (Cast) expr;
+ if (isPosNegArithmetic.test(cast.child())) {
+ arithmetic = (BinaryArithmetic) cast.child();
+ castType = Optional.of(cast.getDataType());
+ }
+ }
+ if (arithmetic != null) {
+ doFlatten(flag, arithmetic.left(), isAddOrSub, result, castType);
+ if (isNegativeArithmetic.test(arithmetic) && !flag) {
+ doFlatten(true, arithmetic.right(), isAddOrSub, result,
castType);
+ } else if (isPositiveArithmetic.test(arithmetic) && !flag) {
+ doFlatten(false, arithmetic.right(), isAddOrSub, result,
castType);
} else {
- flattenAddSubtract(!TypeUtils.isSubtract(expr),
arithmetic.right(), result);
+ doFlatten(!isNegativeArithmetic.test(arithmetic),
arithmetic.right(), isAddOrSub, result, castType);
}
} else {
- result.add(Operand.of(flag, expr));
+ if (castType.isPresent()) {
+ result.add(Operand.of(flag,
TypeCoercionUtils.castIfNotSameType(expr, castType.get())));
+ } else {
+ result.add(Operand.of(flag, expr));
+ }
}
}
- private static void flattenMultiplyDivide(boolean flag, Expression expr,
List<Operand> result) {
- if (TypeUtils.isMultiplyOrDivide(expr)) {
- BinaryArithmetic arithmetic = (BinaryArithmetic) expr;
- flattenMultiplyDivide(flag, arithmetic.left(), result);
- if (TypeUtils.isDivide(expr) && !flag) {
- flattenMultiplyDivide(true, arithmetic.right(), result);
- } else if (TypeUtils.isMultiply(expr) && !flag) {
- flattenMultiplyDivide(false, arithmetic.right(), result);
- } else {
- flattenMultiplyDivide(!TypeUtils.isDivide(expr),
arithmetic.right(), result);
+ private static boolean hasConstantOperand(Expression expr, boolean
isAddOrSub) {
+ if (expr.isConstant()) {
+ return true;
+ }
+
+ Predicate<Expression> checkArithmetic = isAddOrSub
+ ? TypeUtils::isAddOrSubtract : TypeUtils::isMultiplyOrDivide;
+ BinaryArithmetic arithmetic = null;
+ if (checkArithmetic.test(expr)) {
+ arithmetic = (BinaryArithmetic) expr;
+ } else if (expr instanceof Cast) {
+ Cast cast = (Cast) expr;
+ if (checkArithmetic.test(cast.child())) {
+ arithmetic = (BinaryArithmetic) cast.child();
}
+ }
+ if (arithmetic != null) {
+ return hasConstantOperand(arithmetic.left(), isAddOrSub)
+ || hasConstantOperand(arithmetic.right(), isAddOrSub);
} else {
- result.add(Operand.of(flag, expr));
+ return false;
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
index f23aefe5267..92eb90e93b0 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyArithmeticRuleTest.java
@@ -46,12 +46,19 @@ class SimplifyArithmeticRuleTest extends
ExpressionRewriteTestHelper {
assertRewriteAfterTypeCoercion("IA + 2 - ((1 - IB) - (3 + IC))", "IA +
IB + IC + 4");
assertRewriteAfterTypeCoercion("IA * IB + 2 - IC * 2", "(IA * IB) -
(IC * 2) + 2");
assertRewriteAfterTypeCoercion("IA * IB", "IA * IB");
+
assertRewriteAfterTypeCoercion("IA * IB / 2 * 2", "cast((IA * IB) as
DOUBLE) / 1.0");
assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as
DOUBLE) / 4.0");
assertRewriteAfterTypeCoercion("IA * IB / (2 * 2)", "cast((IA * IB) as
DOUBLE) / 4.0");
assertRewriteAfterTypeCoercion("IA * (IB / 2) * 2)", "cast(IA as
DOUBLE) * cast(IB as DOUBLE) / 1.0");
assertRewriteAfterTypeCoercion("IA * (IB / 2) * (IC + 1))", "cast(IA
as DOUBLE) * cast(IB as DOUBLE) * cast((IC + 1) as DOUBLE) / 2.0");
assertRewriteAfterTypeCoercion("IA * IB / 2 / IC * 2 * ID / 4",
"(((cast((IA * IB) as DOUBLE) / cast(IC as DOUBLE)) * cast(ID as DOUBLE)) /
4.0)");
+ assertRewriteAfterTypeCoercion("-1 + (10 - 20) * (3 - 6) - (100 -
200) * (6 - 3)", "329");
+ assertRewriteAfterTypeCoercion("IA - 10 + (IB * 2 * 3) + 20", "IA +
(IB * 6) - (-10)");
+ assertRewriteAfterTypeCoercion("IA / 10 * (IB - 2 + 3) * 20",
"((cast(IA as DOUBLE) * cast((IB - (-1)) as DOUBLE)) / 0.5)");
+ assertRewriteAfterTypeCoercion("1 + ((IA * 2 * 3) * 10 / 10)",
"((cast(IA as DOUBLE) * 6.0) + 1.0)");
+ assertRewriteAfterTypeCoercion("1 + (IA * 2 * 20 / (IB + 5 + (IC * 10
* 20 / 50 + 5 + 6) + 20) / 20) * (ID * 5 * 6 / (IE + 20 + 30)) + 200",
+ "(((((cast(IA as DOUBLE) / ((cast(IB as DOUBLE) + (cast(IC as
DOUBLE) * 4.0)) + 36.0)) * cast(ID as DOUBLE)) / cast((IE + 50) as DOUBLE)) *
60.0) + 201.0)");
}
@Test
@@ -69,18 +76,24 @@ class SimplifyArithmeticRuleTest extends
ExpressionRewriteTestHelper {
assertRewriteAfterTypeCoercion("IA - 2 - ((-IB - 1) - (3 + (IC +
4)))", "(((IA + IB) + IC) - ((((2 + 0) - 1) - 3) - 4))");
// multiply and divide
- assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))",
"((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as
DOUBLE)) * cast((IC * 3) as DOUBLE))");
- assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))",
"(((cast(IA as DOUBLE) / cast((IB * 1) as DOUBLE)) / cast(IC as DOUBLE)) /
((cast(2 as DOUBLE) / cast(3 as DOUBLE)) / cast(4 as DOUBLE)))");
- assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))",
"(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast((IC * 4) as DOUBLE)) /
((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)))");
- assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))",
"(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast((IC * (3 * 4)) as DOUBLE))
/ (cast(2 as DOUBLE) / cast(1 as DOUBLE)))");
+ assertRewriteAfterTypeCoercion("2 / IA / ((1 / IB) / (3 * IC))",
+ "(((((cast(2 as DOUBLE) / cast(1 as DOUBLE)) * cast (3 as
DOUBLE)) / cast(IA as DOUBLE)) * cast(IB as DOUBLE)) * cast(IC as DOUBLE))");
+ assertRewriteAfterTypeCoercion("IA / 2 / ((IB * 1) / (3 / (IC / 4)))",
+ "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast(IC as
DOUBLE)) / (((cast(2 as DOUBLE) * cast(1 as DOUBLE)) / cast(3 as DOUBLE)) /
cast(4 as DOUBLE)))");
+ assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 / (IC * 4)))",
+ "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) / cast(IC as
DOUBLE)) / (((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)) *
cast(4 as DOUBLE)))");
+ assertRewriteAfterTypeCoercion("IA / 2 / ((IB / 1) / (3 * (IC * 4)))",
+ "(((cast(IA as DOUBLE) / cast(IB as DOUBLE)) * cast(IC as
DOUBLE)) / (((cast(2 as DOUBLE) / cast(1 as DOUBLE)) / cast(3 as DOUBLE)) /
cast(4 as DOUBLE)))");
// hybrid
// root is subtract
- assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))",
"(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * (cast((1 - IB) as DOUBLE) -
(cast(3 as DOUBLE) / cast(IC as DOUBLE)))))");
- assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC /
4)))", "((cast(((0 - 2) - IA) as DOUBLE) - cast((IB * 1) as DOUBLE)) + (cast(3
as DOUBLE) * (cast(IC as DOUBLE) / cast(4 as DOUBLE))))");
+ assertRewriteAfterTypeCoercion("-2 - IA * ((1 - IB) - (3 / IC))",
+ "(cast(-2 as DOUBLE) - (cast(IA as DOUBLE) * ((cast(1 as
DOUBLE) - cast(IB as DOUBLE)) - (cast(3 as DOUBLE) / cast(IC as DOUBLE)))))");
+ assertRewriteAfterTypeCoercion("-IA - 2 - ((IB * 1) - (3 * (IC / 4)))",
+ "((((cast(0 as DOUBLE) - cast(2 as DOUBLE)) - cast(IA as
DOUBLE)) - cast((IB * 1) as DOUBLE)) + (cast(IC as DOUBLE) * (cast(3 as DOUBLE)
/ cast(4 as DOUBLE))))");
// root is add
assertRewriteAfterTypeCoercion("-IA * 2 + ((IB - 1) / (3 - (IC +
4)))", "(cast(((0 - IA) * 2) as DOUBLE) + (cast((IB - 1) as DOUBLE) / cast(((3
- 4) - IC) as DOUBLE)))");
- assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC +
4)))", "(((((0 + 2) - 1) - IA) + IB) - (3 * (IC + 4)))");
+ assertRewriteAfterTypeCoercion("-IA + 2 + ((IB - 1) - (3 * (IC +
4)))", "(((((0 + 2) - 1) - IA) + IB) - ((IC + 4) * 3))");
// root is multiply
assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) - (3 + (IC +
4)))", "((cast((0 - IA) as DOUBLE) * cast((((((0 - 1) - 3) - 4) - IB) - IC) as
DOUBLE)) / cast(2 as DOUBLE))");
assertRewriteAfterTypeCoercion("-IA / 2 * ((-IB - 1) * (3 / (IC +
4)))", "(((cast((0 - IA) as DOUBLE) * cast(((0 - 1) - IB) as DOUBLE)) /
cast((IC + 4) as DOUBLE)) / (cast(2 as DOUBLE) / cast(3 as DOUBLE)))");
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]