This is an automated email from the ASF dual-hosted git repository.
morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new b5c0bcd3434 [fix](simplify agg) SimplifyAggGroupBy should verify
injectivity (#64335)
b5c0bcd3434 is described below
commit b5c0bcd3434f8823de33ff0c94d246a5f074ccbf
Author: yujun <[email protected]>
AuthorDate: Wed Jul 1 17:54:27 2026 +0800
[fix](simplify agg) SimplifyAggGroupBy should verify injectivity (#64335)
## Problem
`SimplifyAggGroupBy` simplified `GROUP BY f(x)` to `GROUP BY x` without
verifying that `f(x)` is injective (one-to-one). This caused wrong
results:
| Expression | Why wrong |
|---|---|
| `a * 0` / `0 * a` | always evaluates to 0 — all rows fall into one group |
| `0 / a` | always evaluates to 0 |
| `a / 0` | division by zero |
| `a + NULL` / `a * NULL` / ... | always evaluates to NULL |
| `a * 0.1` with float/double | precision loss may map different inputs to the same result |
## Fix
1. **`isBinaryArithmeticSlot`**: restructured to separate the slot
expression from the literal, then validate each independently. The
float/double check runs early, before slot extraction.
2. **New `checkLiteral(expr, literal)`**: rejects NULL literal and
Multiply/Divide by zero.
3. **New `canExtractSlot(expr)`**: replaces the old unconditional
`extractSlotOrCastOnSlot` — only accepts a bare `Slot` or a chain of
lossless widening casts on a `Slot` (integral→integral, float→double,
integral→decimal, decimal→decimal). Implicit and explicit casts are
treated alike: every cast in the chain must pass `isInjectiveCastTo`.
Range and scale are compared directly for correctness.
## Changes
- `SimplifyAggGroupBy.java`: +80 lines, rewritten core logic
- `ExpressionUtils.java`: -35 lines, removed the now-unused
`isSlotOrCastOnSlot` / `extractSlotOrCastOnSlot`
- `SimplifyAggGroupByTest.java`: +216 lines, 25 tests covering all new
paths
---------
Co-authored-by: Claude Opus 4.7 <[email protected]>
---
.../nereids/rules/rewrite/SimplifyAggGroupBy.java | 57 ++++-
.../apache/doris/nereids/util/ExpressionUtils.java | 36 ----
.../rules/rewrite/SimplifyAggGroupByTest.java | 240 +++++++++++++++++++++
3 files changed, 293 insertions(+), 40 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
index 37d4d4806f0..f45a6da8745 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
@@ -19,9 +19,9 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
@@ -71,7 +71,7 @@ public class SimplifyAggGroupBy extends OneRewriteRuleFactory
{
}
@VisibleForTesting
- protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr)
{
+ protected static boolean isBinaryArithmeticSlot(Expression expr) {
if (expr instanceof Slot) {
return true;
}
@@ -81,7 +81,56 @@ public class SimplifyAggGroupBy extends
OneRewriteRuleFactory {
if (!supportedFunctions.contains(expr.getClass())) {
return false;
}
- return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent()
&& expr.child(1) instanceof Literal
- ||
ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0)
instanceof Literal;
+
+ // Float/double arithmetic: precision loss for all operations
+ if (expr.child(0).getDataType().isFloatLikeType()
+ || expr.child(1).getDataType().isFloatLikeType()) {
+ return false;
+ }
+
+ Expression slotExpr;
+ Literal literal;
+ if (expr.child(0) instanceof Literal) {
+ literal = (Literal) expr.child(0);
+ slotExpr = expr.child(1);
+ } else if (expr.child(1) instanceof Literal) {
+ literal = (Literal) expr.child(1);
+ slotExpr = expr.child(0);
+ } else {
+ return false;
+ }
+
+ if (!canExtractSlot(slotExpr)) {
+ return false;
+ }
+
+ return checkLiteral((BinaryArithmetic) expr, literal);
}
+
+ @VisibleForTesting
+ protected static boolean checkLiteral(BinaryArithmetic expr, Literal
literal) {
+ if (literal.isNullLiteral()) {
+ return false;
+ }
+ if (expr instanceof Multiply || expr instanceof Divide) {
+ if (literal.isZero()) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @VisibleForTesting
+ protected static boolean canExtractSlot(Expression expr) {
+ while (expr instanceof Cast) {
+ Cast cast = (Cast) expr;
+ Expression inner = cast.child();
+ if (!inner.getDataType().isInjectiveCastTo(cast.getDataType())) {
+ return false;
+ }
+ expr = inner;
+ }
+ return expr instanceof Slot;
+ }
+
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 7fa7acfa869..6d318cafd9c 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -42,7 +42,6 @@ import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
@@ -421,41 +420,6 @@ public class ExpressionUtils {
return minSlot;
}
- /**
- * Check whether the input expression is a {@link
org.apache.doris.nereids.trees.expressions.Slot}
- * or at least one {@link Cast} on a {@link
org.apache.doris.nereids.trees.expressions.Slot}
- * <p>
- * for example:
- * - SlotReference to a column:
- * col
- * - Cast on SlotReference:
- * cast(int_col as string)
- * cast(cast(int_col as long) as string)
- *
- * @param expr input expression
- * @return Return Optional[ExprId] of underlying slot reference if input
expression is a slot or cast on slot.
- * Otherwise, return empty optional result.
- */
- public static Optional<ExprId> isSlotOrCastOnSlot(Expression expr) {
- return extractSlotOrCastOnSlot(expr).map(Slot::getExprId);
- }
-
- /**
- * Check whether the input expression is a {@link
org.apache.doris.nereids.trees.expressions.Slot}
- * or at least one {@link Cast} on a {@link
org.apache.doris.nereids.trees.expressions.Slot}
- */
- public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
- while (expr instanceof Cast) {
- expr = expr.child(0);
- }
-
- if (expr instanceof SlotReference) {
- return Optional.of((Slot) expr);
- } else {
- return Optional.empty();
- }
- }
-
/**
* Generate replaceMap Slot -> Expression from NamedExpression[Expression
as name]
*/
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
index 32c2cc4356d..f6a822d5b65 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
@@ -18,18 +18,30 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Mod;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
@@ -41,6 +53,7 @@ import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
+import java.math.BigDecimal;
import java.util.List;
class SimplifyAggGroupByTest implements MemoPatternMatchSupported {
@@ -156,4 +169,231 @@ class SimplifyAggGroupByTest implements
MemoPatternMatchSupported {
Divide divide = new Divide(id, Literal.of(2));
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide));
}
+
+ // ========== new tests for injectivity checks ==========
+
+ @Test
+ void testMultiplyByZero() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Multiply(id, Literal.of(0))));
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Multiply(Literal.of(0), id)));
+ }
+
+ @Test
+ void testDivideZeroNumerator() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Divide(Literal.of(0), id)));
+ }
+
+ @Test
+ void testDivideByZero() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Divide(id, Literal.of(0))));
+ }
+
+ @Test
+ void testNullLiteral() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Add(id, NullLiteral.INSTANCE)));
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Multiply(id, NullLiteral.INSTANCE)));
+ }
+
+ @Test
+ void testMultiplyWithDoubleLiteral() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Multiply(id, new DoubleLiteral(0.1))));
+ }
+
+ @Test
+ void testDivideWithDoubleLiteral() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Divide(id, new DoubleLiteral(2.0))));
+ }
+
+ @Test
+ void testMultiplyWithFloatSlot() {
+ Slot floatSlot = new SlotReference("f", FloatType.INSTANCE);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Multiply(floatSlot, Literal.of(2))));
+ }
+
+ @Test
+ void testMultiplyDoubleSlotWithIntLiteral() {
+ Slot doubleSlot = new SlotReference("d", DoubleType.INSTANCE);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Multiply(doubleSlot, Literal.of(2))));
+ }
+
+ @Test
+ void testAddWithDoubleLiteral() {
+ // Float/double arithmetic may be imprecise, reject for all ops
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Add(id, new DoubleLiteral(1.0))));
+ }
+
+ @Test
+ void testAddWithFloatLiteral() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Add(id, new FloatLiteral(1.0f))));
+ }
+
+ @Test
+ void testSubtractWithDoubleLiteral() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Subtract(id, new DoubleLiteral(1.0))));
+ }
+
+ @Test
+ void testMultiplyWithDecimalLiteral() {
+ // Small decimal multiply should pass (precision fits)
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Multiply(id, new DecimalLiteral(new BigDecimal("2.0")))));
+ }
+
+ @Test
+ void testDivideWithDecimalLiteral() {
+ // Divide with decimal: precision overflow too extreme to worry about
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Divide(id, new DecimalLiteral(new BigDecimal("2.0")))));
+ }
+
+ @Test
+ void testAddWithDecimalLiteral() {
+ // Add/Subtract with decimal are exact, should pass
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+ new Add(id, new DecimalLiteral(new BigDecimal("1.0")))));
+ }
+
+ // ========== tests for isInjectiveCastTo ==========
+
+ @Test
+ void testIntegerWidening() {
+
Assertions.assertTrue(TinyIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+
Assertions.assertTrue(IntegerType.INSTANCE.isInjectiveCastTo(BigIntType.INSTANCE));
+
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(TinyIntType.INSTANCE));
+
Assertions.assertFalse(BigIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+ }
+
+ @Test
+ void testDecimalWidening() {
+ Assertions.assertTrue(DecimalV3Type.createDecimalV3Type(5, 2)
+ .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 4)));
+ Assertions.assertFalse(DecimalV3Type.createDecimalV3Type(10, 4)
+ .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 2)));
+ }
+
+ @Test
+ void testIntegralToDecimalWidening() {
+ Assertions.assertTrue(TinyIntType.INSTANCE
+ .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 0)));
+ // BigInt has 19 digits, DECIMAL(5,0) only has 5 integer digits
+ Assertions.assertFalse(BigIntType.INSTANCE
+ .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 0)));
+ }
+
+ @Test
+ void testCrossFamilyRejected() {
+
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(FloatType.INSTANCE));
+
Assertions.assertFalse(FloatType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(DoubleType.INSTANCE));
+ }
+
+ // ========== tests for canExtractSlot ==========
+
+ @Test
+ void testCanExtractSlotBare() {
+ Slot id = scan1.getOutput().get(0);
+ Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(id));
+ }
+
+ @Test
+ void testCanExtractSlotWidening() {
+ Slot id = scan1.getOutput().get(0);
+ // INT->BIGINT is lossless widening
+ Expression cast = new Cast(id, BigIntType.INSTANCE);
+ Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
+ }
+
+ @Test
+ void testCanExtractSlotExplicitCast() {
+ Slot id = scan1.getOutput().get(0);
+ // explicit cast should also be acceptable if lossless
+ Expression cast = new Cast(id, BigIntType.INSTANCE, true);
+ Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
+ }
+
+ @Test
+ void testCanExtractSlotNarrowing() {
+ Slot id = scan1.getOutput().get(0);
+ // INT -> TINYINT is narrowing, should be rejected
+ Expression cast = new Cast(id, TinyIntType.INSTANCE);
+ Assertions.assertFalse(SimplifyAggGroupBy.canExtractSlot(cast));
+ }
+
+ // ========== integration tests via PlanChecker ==========
+
+ @Test
+ void testMultiplyByZeroNotSimplified() {
+ Slot id = scan1.getOutput().get(0);
+ List<NamedExpression> output = ImmutableList.of(id, new
Count().alias("cnt"));
+ List<Expression> groupBy = ImmutableList.of(id, new Multiply(id,
Literal.of(0)));
+ LogicalPlan agg = new LogicalPlanBuilder(scan1)
+ .agg(groupBy, output)
+ .build();
+ ConnectContext connectContext = MemoTestUtils.createConnectContext();
+
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
+ PlanChecker.from(connectContext, agg)
+ .applyTopDown(new SimplifyAggGroupBy())
+ .matchesFromRoot(
+ logicalAggregate().when(a -> a.equals(agg))
+ );
+ }
+
+ @Test
+ void testNullLiteralNotSimplified() {
+ Slot id = scan1.getOutput().get(0);
+ List<NamedExpression> output = ImmutableList.of(id, new
Count().alias("cnt"));
+ List<Expression> groupBy = ImmutableList.of(id, new Add(id,
NullLiteral.INSTANCE));
+ LogicalPlan agg = new LogicalPlanBuilder(scan1)
+ .agg(groupBy, output)
+ .build();
+ ConnectContext connectContext = MemoTestUtils.createConnectContext();
+
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
+ PlanChecker.from(connectContext, agg)
+ .applyTopDown(new SimplifyAggGroupBy())
+ .matchesFromRoot(
+ logicalAggregate().when(a -> a.equals(agg))
+ );
+ }
+
+ @Test
+ void testMultiplyDoubleLiteralNotSimplified() {
+ Slot id = scan1.getOutput().get(0);
+ List<NamedExpression> output = ImmutableList.of(id, new
Count().alias("cnt"));
+ List<Expression> groupBy = ImmutableList.of(id, new Multiply(id, new
DoubleLiteral(0.1)));
+ LogicalPlan agg = new LogicalPlanBuilder(scan1)
+ .agg(groupBy, output)
+ .build();
+ ConnectContext connectContext = MemoTestUtils.createConnectContext();
+
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
+ PlanChecker.from(connectContext, agg)
+ .applyTopDown(new SimplifyAggGroupBy())
+ .matchesFromRoot(
+ logicalAggregate().when(a -> a.equals(agg))
+ );
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]