This is an automated email from the ASF dual-hosted git repository.

morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new b5c0bcd3434 [fix](simplify agg) SimplifyAggGroupBy should verify 
injectivity (#64335)
b5c0bcd3434 is described below

commit b5c0bcd3434f8823de33ff0c94d246a5f074ccbf
Author: yujun <[email protected]>
AuthorDate: Wed Jul 1 17:54:27 2026 +0800

    [fix](simplify agg) SimplifyAggGroupBy should verify injectivity (#64335)
    
    ## Problem
    
    `SimplifyAggGroupBy` simplified `GROUP BY f(x)` to `GROUP BY x` without
    verifying that `f(x)` is injective (one-to-one). This caused wrong
    results:
    
    | Expression | Why wrong |
    |---|---|
    | `a * 0` / `0 * a` | always evaluates to 0 — all rows fall into one group |
    | `0 / a` | always evaluates to 0 |
    | `a / 0` | division by zero |
    | `a + NULL` / `a * NULL` / ... | always evaluates to NULL |
    | `a * 0.1` with float/double | precision loss may map different inputs to the same result |
    
    ## Fix
    
    1. **`isBinaryArithmeticSlot`**: restructured to separate the slot
       expression from the literal, then validate each independently. The
       float/double check runs early, before slot extraction.
    
    2. **New `checkLiteral(expr, literal)`**: rejects NULL literal and
       Multiply/Divide by zero.
    
    3. **New `canExtractSlot(expr)`**: replaces the old unconditional
       `extractSlotOrCastOnSlot` — only accepts a bare `Slot` or lossless
       widening casts (integral→integral, float→double, integral→decimal,
       decimal→decimal), whether the cast is implicit or explicit. Range and
       scale are compared directly for correctness.
    
    ## Changes
    
    - `SimplifyAggGroupBy.java`: +53/-4 lines, rewritten core logic
    - `ExpressionUtils.java`: -36 lines, removed unused `isSlotOrCastOnSlot` /
      `extractSlotOrCastOnSlot` (and the now-unused `ExprId` import)
    - `SimplifyAggGroupByTest.java`: +240 lines, 25 tests covering all new
      paths
    
    ---------
    
    Co-authored-by: Claude Opus 4.7 <[email protected]>
---
 .../nereids/rules/rewrite/SimplifyAggGroupBy.java  |  57 ++++-
 .../apache/doris/nereids/util/ExpressionUtils.java |  36 ----
 .../rules/rewrite/SimplifyAggGroupByTest.java      | 240 +++++++++++++++++++++
 3 files changed, 293 insertions(+), 40 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
index 37d4d4806f0..f45a6da8745 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java
@@ -19,9 +19,9 @@ package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.TreeNode;
 import org.apache.doris.nereids.trees.expressions.Add;
 import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Divide;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Multiply;
@@ -71,7 +71,7 @@ public class SimplifyAggGroupBy extends OneRewriteRuleFactory 
{
     }
 
     @VisibleForTesting
-    protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) 
{
+    protected static boolean isBinaryArithmeticSlot(Expression expr) {
         if (expr instanceof Slot) {
             return true;
         }
@@ -81,7 +81,56 @@ public class SimplifyAggGroupBy extends 
OneRewriteRuleFactory {
         if (!supportedFunctions.contains(expr.getClass())) {
             return false;
         }
-        return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() 
&& expr.child(1) instanceof Literal
-                || 
ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) 
instanceof Literal;
+
+        // Float/double arithmetic: precision loss for all operations
+        if (expr.child(0).getDataType().isFloatLikeType()
+                || expr.child(1).getDataType().isFloatLikeType()) {
+            return false;
+        }
+
+        Expression slotExpr;
+        Literal literal;
+        if (expr.child(0) instanceof Literal) {
+            literal = (Literal) expr.child(0);
+            slotExpr = expr.child(1);
+        } else if (expr.child(1) instanceof Literal) {
+            literal = (Literal) expr.child(1);
+            slotExpr = expr.child(0);
+        } else {
+            return false;
+        }
+
+        if (!canExtractSlot(slotExpr)) {
+            return false;
+        }
+
+        return checkLiteral((BinaryArithmetic) expr, literal);
     }
+
+    @VisibleForTesting
+    protected static boolean checkLiteral(BinaryArithmetic expr, Literal 
literal) {
+        if (literal.isNullLiteral()) {
+            return false;
+        }
+        if (expr instanceof Multiply || expr instanceof Divide) {
+            if (literal.isZero()) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    @VisibleForTesting
+    protected static boolean canExtractSlot(Expression expr) {
+        while (expr instanceof Cast) {
+            Cast cast = (Cast) expr;
+            Expression inner = cast.child();
+            if (!inner.getDataType().isInjectiveCastTo(cast.getDataType())) {
+                return false;
+            }
+            expr = inner;
+        }
+        return expr instanceof Slot;
+    }
+
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 7fa7acfa869..6d318cafd9c 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -42,7 +42,6 @@ import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
-import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InPredicate;
 import org.apache.doris.nereids.trees.expressions.IsNull;
@@ -421,41 +420,6 @@ public class ExpressionUtils {
         return minSlot;
     }
 
-    /**
-     * Check whether the input expression is a {@link 
org.apache.doris.nereids.trees.expressions.Slot}
-     * or at least one {@link Cast} on a {@link 
org.apache.doris.nereids.trees.expressions.Slot}
-     * <p>
-     * for example:
-     * - SlotReference to a column:
-     * col
-     * - Cast on SlotReference:
-     * cast(int_col as string)
-     * cast(cast(int_col as long) as string)
-     *
-     * @param expr input expression
-     * @return Return Optional[ExprId] of underlying slot reference if input 
expression is a slot or cast on slot.
-     *         Otherwise, return empty optional result.
-     */
-    public static Optional<ExprId> isSlotOrCastOnSlot(Expression expr) {
-        return extractSlotOrCastOnSlot(expr).map(Slot::getExprId);
-    }
-
-    /**
-     * Check whether the input expression is a {@link 
org.apache.doris.nereids.trees.expressions.Slot}
-     * or at least one {@link Cast} on a {@link 
org.apache.doris.nereids.trees.expressions.Slot}
-     */
-    public static Optional<Slot> extractSlotOrCastOnSlot(Expression expr) {
-        while (expr instanceof Cast) {
-            expr = expr.child(0);
-        }
-
-        if (expr instanceof SlotReference) {
-            return Optional.of((Slot) expr);
-        } else {
-            return Optional.empty();
-        }
-    }
-
     /**
      * Generate replaceMap Slot -> Expression from NamedExpression[Expression 
as name]
      */
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
index 32c2cc4356d..f6a822d5b65 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java
@@ -18,18 +18,30 @@
 package org.apache.doris.nereids.rules.rewrite;
 
 import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Divide;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Mod;
 import org.apache.doris.nereids.trees.expressions.Multiply;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.Subtract;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.FloatType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.TinyIntType;
 import org.apache.doris.nereids.util.LogicalPlanBuilder;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
@@ -41,6 +53,7 @@ import com.google.common.collect.ImmutableList;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.math.BigDecimal;
 import java.util.List;
 
 class SimplifyAggGroupByTest implements MemoPatternMatchSupported {
@@ -156,4 +169,231 @@ class SimplifyAggGroupByTest implements 
MemoPatternMatchSupported {
         Divide divide = new Divide(id, Literal.of(2));
         
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide));
     }
+
+    // ========== new tests for injectivity checks ==========
+
+    @Test
+    void testMultiplyByZero() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, Literal.of(0))));
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(Literal.of(0), id)));
+    }
+
+    @Test
+    void testDivideZeroNumerator() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(Literal.of(0), id)));
+    }
+
+    @Test
+    void testDivideByZero() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(id, Literal.of(0))));
+    }
+
+    @Test
+    void testNullLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, NullLiteral.INSTANCE)));
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, NullLiteral.INSTANCE)));
+    }
+
+    @Test
+    void testMultiplyWithDoubleLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, new DoubleLiteral(0.1))));
+    }
+
+    @Test
+    void testDivideWithDoubleLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(id, new DoubleLiteral(2.0))));
+    }
+
+    @Test
+    void testMultiplyWithFloatSlot() {
+        Slot floatSlot = new SlotReference("f", FloatType.INSTANCE);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(floatSlot, Literal.of(2))));
+    }
+
+    @Test
+    void testMultiplyDoubleSlotWithIntLiteral() {
+        Slot doubleSlot = new SlotReference("d", DoubleType.INSTANCE);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(doubleSlot, Literal.of(2))));
+    }
+
+    @Test
+    void testAddWithDoubleLiteral() {
+        // Float/double arithmetic may be imprecise, reject for all ops
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, new DoubleLiteral(1.0))));
+    }
+
+    @Test
+    void testAddWithFloatLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, new FloatLiteral(1.0f))));
+    }
+
+    @Test
+    void testSubtractWithDoubleLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Subtract(id, new DoubleLiteral(1.0))));
+    }
+
+    @Test
+    void testMultiplyWithDecimalLiteral() {
+        // Small decimal multiply should pass (precision fits)
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, new DecimalLiteral(new BigDecimal("2.0")))));
+    }
+
+    @Test
+    void testDivideWithDecimalLiteral() {
+        // Divide with decimal: precision overflow too extreme to worry about
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(id, new DecimalLiteral(new BigDecimal("2.0")))));
+    }
+
+    @Test
+    void testAddWithDecimalLiteral() {
+        // Add/Subtract with decimal are exact, should pass
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, new DecimalLiteral(new BigDecimal("1.0")))));
+    }
+
+    // ========== tests for isInjectiveCastTo ==========
+
+    @Test
+    void testIntegerWidening() {
+        
Assertions.assertTrue(TinyIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+        
Assertions.assertTrue(IntegerType.INSTANCE.isInjectiveCastTo(BigIntType.INSTANCE));
+        
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(TinyIntType.INSTANCE));
+        
Assertions.assertFalse(BigIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+    }
+
+    @Test
+    void testDecimalWidening() {
+        Assertions.assertTrue(DecimalV3Type.createDecimalV3Type(5, 2)
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 4)));
+        Assertions.assertFalse(DecimalV3Type.createDecimalV3Type(10, 4)
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 2)));
+    }
+
+    @Test
+    void testIntegralToDecimalWidening() {
+        Assertions.assertTrue(TinyIntType.INSTANCE
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 0)));
+        // BigInt has 19 digits, DECIMAL(5,0) only has 5 integer digits
+        Assertions.assertFalse(BigIntType.INSTANCE
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 0)));
+    }
+
+    @Test
+    void testCrossFamilyRejected() {
+        
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(FloatType.INSTANCE));
+        
Assertions.assertFalse(FloatType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+        
Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(DoubleType.INSTANCE));
+    }
+
+    // ========== tests for canExtractSlot ==========
+
+    @Test
+    void testCanExtractSlotBare() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(id));
+    }
+
+    @Test
+    void testCanExtractSlotWidening() {
+        Slot id = scan1.getOutput().get(0);
+        // INT->BIGINT is lossless widening
+        Expression cast = new Cast(id, BigIntType.INSTANCE);
+        Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
+    }
+
+    @Test
+    void testCanExtractSlotExplicitCast() {
+        Slot id = scan1.getOutput().get(0);
+        // explicit cast should also be acceptable if lossless
+        Expression cast = new Cast(id, BigIntType.INSTANCE, true);
+        Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
+    }
+
+    @Test
+    void testCanExtractSlotNarrowing() {
+        Slot id = scan1.getOutput().get(0);
+        // INT -> TINYINT is narrowing, should be rejected
+        Expression cast = new Cast(id, TinyIntType.INSTANCE);
+        Assertions.assertFalse(SimplifyAggGroupBy.canExtractSlot(cast));
+    }
+
+    // ========== integration tests via PlanChecker ==========
+
+    @Test
+    void testMultiplyByZeroNotSimplified() {
+        Slot id = scan1.getOutput().get(0);
+        List<NamedExpression> output = ImmutableList.of(id, new 
Count().alias("cnt"));
+        List<Expression> groupBy = ImmutableList.of(id, new Multiply(id, 
Literal.of(0)));
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .agg(groupBy, output)
+                .build();
+        ConnectContext connectContext = MemoTestUtils.createConnectContext();
+        
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
+        PlanChecker.from(connectContext, agg)
+                .applyTopDown(new SimplifyAggGroupBy())
+                .matchesFromRoot(
+                        logicalAggregate().when(a -> a.equals(agg))
+                );
+    }
+
+    @Test
+    void testNullLiteralNotSimplified() {
+        Slot id = scan1.getOutput().get(0);
+        List<NamedExpression> output = ImmutableList.of(id, new 
Count().alias("cnt"));
+        List<Expression> groupBy = ImmutableList.of(id, new Add(id, 
NullLiteral.INSTANCE));
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .agg(groupBy, output)
+                .build();
+        ConnectContext connectContext = MemoTestUtils.createConnectContext();
+        
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
+        PlanChecker.from(connectContext, agg)
+                .applyTopDown(new SimplifyAggGroupBy())
+                .matchesFromRoot(
+                        logicalAggregate().when(a -> a.equals(agg))
+                );
+    }
+
+    @Test
+    void testMultiplyDoubleLiteralNotSimplified() {
+        Slot id = scan1.getOutput().get(0);
+        List<NamedExpression> output = ImmutableList.of(id, new 
Count().alias("cnt"));
+        List<Expression> groupBy = ImmutableList.of(id, new Multiply(id, new 
DoubleLiteral(0.1)));
+        LogicalPlan agg = new LogicalPlanBuilder(scan1)
+                .agg(groupBy, output)
+                .build();
+        ConnectContext connectContext = MemoTestUtils.createConnectContext();
+        
connectContext.getSessionVariable().setEnableMaterializedViewRewrite(false);
+        PlanChecker.from(connectContext, agg)
+                .applyTopDown(new SimplifyAggGroupBy())
+                .matchesFromRoot(
+                        logicalAggregate().when(a -> a.equals(agg))
+                );
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to