This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-4.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 0c6de75fe66f02f34811b24beafcd9763d4f9e81 Author: minghong <[email protected]> AuthorDate: Fri Mar 13 23:14:05 2026 +0800 branch-4.1 [opt](Nereids) strip redundant widening integer cast in SumLiteralRewrite (#61224) (#61308) pick #61224 ### What problem does this PR solve? SumLiteralRewrite transforms SUM(expr +/- literal) into SUM(expr) +/- literal * COUNT(expr). When type coercion has introduced an implicit widening cast (e.g. CAST(smallint_col AS INT)), the rewritten SUM/COUNT still operates on the wider type, forcing unnecessary wider data reads. This is redundant because SUM always returns BIGINT for any integer input (TINYINT/SMALLINT/INT/BIGINT). Strip implicit widening integer casts in extractSumLiteral() so the aggregate operates on the original narrow column directly. This benefits ClickBench Q29-style queries where SUM(col), SUM(col+1), SUM(col+2) share a narrow integer column — after stripping the cast, SUM(col+1) and SUM(col+2) reuse the existing SUM(col). ### What problem does this PR solve? Issue Number: close #xxx Related PR: #xxx Problem Summary: ### Release note None ### Check List (For Author) - Test <!-- At least one of them must be included. --> - [ ] Regression test - [ ] Unit Test - [ ] Manual test (add detailed scripts or steps below) - [ ] No need to test or manual test. Explain why: - [ ] This is a refactor/code format and no logic has been changed. - [ ] Previous test can cover this change. - [ ] No code files have been changed. - [ ] Other reason <!-- Add your reason? --> - Behavior changed: - [ ] No. - [ ] Yes. <!-- Explain the behavior change --> - Does this need documentation? - [ ] No. - [ ] Yes. <!-- Add document PR link here. eg: https://github.com/apache/doris-website/pull/1214 --> ### Check List (For Reviewer who merge this PR) - [ ] Confirm the release note - [ ] Confirm test cases - [ ] Confirm document - [ ] Add branch pick label <!-- Add branch pick label that this PR should merge into --> --- .../nereids/rules/rewrite/SumLiteralRewrite.java | 33 ++++++++++++ .../rules/rewrite/SumLiteralRewriteTest.java | 60 ++++++++++++++++++++++ .../data/nereids_rules_p0/sumRewrite.out | 8 +-- 3 files changed, 97 insertions(+), 4 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java index 09be00a5819..cf983c03133 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -204,10 +205,42 @@ public class SumLiteralRewrite extends OneRewriteRuleFactory { // only support integer or float types return null; } + // Strip redundant widening integer cast introduced by type coercion. + // e.g. SUM(CAST(smallint_col AS INT) + 1) → after rewrite becomes SUM(CAST(smallint_col AS INT)). + // Since SUM always returns BIGINT for any integer input, CAST(smallint→int) is unnecessary + // and forces wider data reads. Strip it so we get SUM(smallint_col) directly. + left = stripWideningIntegerCast(left); SumInfo info = new SumInfo(left, ((Sum) func).isDistinct(), ((Sum) func).isAlwaysNullable()); return Pair.of(namedExpression, Pair.of(info, (Literal) right)); } + /** + * Strip a widening integer cast that is redundant for SUM/COUNT. + * For example, CAST(smallint_col AS INT) → smallint_col. + * + * This is safe because: + * - SUM returns BIGINT for all integer inputs (TINYINT/SMALLINT/INT/BIGINT), + * so widening the input before aggregation does not change the result. + * - COUNT just counts non-null values, unaffected by widening. + * + * Only implicit (type-coercion) casts between integer-like types are stripped. + */ + private static Expression stripWideningIntegerCast(Expression expr) { + if (!(expr instanceof Cast)) { + return expr; + } + Cast cast = (Cast) expr; + if (cast.isExplicitType()) { + return expr; + } + Expression inner = cast.child(); + if (inner.getDataType().isIntegerLikeType() && cast.getDataType().isIntegerLikeType() + && inner.getDataType().width() <= cast.getDataType().width()) { + return inner; + } + return expr; + } + static class SumInfo { Expression expr; boolean isDistinct; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java index 19ea7b864fb..5b918c62a59 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java @@ -19,12 +19,14 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; @@ -143,4 +145,62 @@ class SumLiteralRewriteTest implements MemoPatternMatchSupported { .matches(logicalAggregate().when(p -> p.getOutputs().size() == 3)); } + + @Test + void testStripWideningIntegerCast() { + Slot slot1 = scan1.getOutput().get(0); + // Simulate type coercion's implicit widening cast: CAST(int_col AS BIGINT) + Cast castSlot = new Cast(slot1, BigIntType.INSTANCE); + Alias add1 = new Alias(new Sum(new Add(castSlot, Literal.of(1)))); + Alias add2 = new Alias(new Sum(new Add(castSlot, Literal.of(2)))); + LogicalAggregate<?> agg = new LogicalAggregate<>( + ImmutableList.of(), ImmutableList.of(add1, add2), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + // After stripping the implicit widening cast, Sum and Count should use + // slot1 directly (not Cast(slot1 AS BIGINT)), so no Cast in aggregate outputs + .matches(logicalAggregate().when(a -> + a.getOutputExpressions().stream().noneMatch( + e -> e.anyMatch(expr -> expr instanceof Cast)))); + + // Verify explicit cast is NOT stripped + Cast explicitCast = new Cast(slot1, BigIntType.INSTANCE, true); + Alias addExplicit1 = new Alias(new Sum(new Add(explicitCast, Literal.of(1)))); + Alias addExplicit2 = new Alias(new Sum(new Add(explicitCast, Literal.of(2)))); + agg = new LogicalAggregate<>( + ImmutableList.of(), ImmutableList.of(addExplicit1, addExplicit2), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + // Explicit cast should be preserved — aggregate outputs should still contain Cast + .matches(logicalAggregate().when(a -> + a.getOutputExpressions().stream().anyMatch( + e -> e.anyMatch(expr -> expr instanceof Cast)))); + } + + @Test + void testStripWideningCastWithExistingSum() { + // Simulates ClickBench Q29: SELECT SUM(col), SUM(col+1), SUM(col+2) + // where col is a narrow integer type and type coercion introduces implicit widening cast. + Slot slot1 = scan1.getOutput().get(0); + // Pre-existing plain SUM(slot) — no cast, no literal + Alias sum = new Alias(new Sum(slot1)); + // Simulate type coercion widening: SUM(CAST(int_col AS BIGINT) + 1) etc. + Cast castSlot = new Cast(slot1, BigIntType.INSTANCE); + Alias add1 = new Alias(new Sum(new Add(castSlot, Literal.of(1)))); + Alias add2 = new Alias(new Sum(new Add(castSlot, Literal.of(2)))); + LogicalAggregate<?> agg = new LogicalAggregate<>( + ImmutableList.of(), ImmutableList.of(sum, add1, add2), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + // After stripping widening cast, the base expr of SUM(CAST(slot AS BIGINT) + n) + // becomes slot — matching the pre-existing SUM(slot). Rewrite reuses it and only + // adds COUNT(slot). Aggregate outputs: sum(slot) + count(slot) = 2. + .matches(logicalAggregate().when(a -> + a.getOutputExpressions().size() == 2 + && a.getOutputExpressions().stream().noneMatch( + e -> e.anyMatch(expr -> expr instanceof Cast)))); + } } diff --git a/regression-test/data/nereids_rules_p0/sumRewrite.out b/regression-test/data/nereids_rules_p0/sumRewrite.out index 356fad37763..265ccae7dea 100644 --- a/regression-test/data/nereids_rules_p0/sumRewrite.out +++ b/regression-test/data/nereids_rules_p0/sumRewrite.out @@ -268,10 +268,10 @@ PhysicalResultSink -- !sum_null_and_not_null_shape -- PhysicalResultSink ---PhysicalProject[(sum(cast(id as BIGINT)) + (count(cast(id as BIGINT)) * 1)) AS `sum(id + 1)`, (sum(cast(id as BIGINT)) - (count(cast(id as BIGINT)) * 1)) AS `sum(id - 1)`, (sum(cast(not_null_id as BIGINT)) + (count(cast(not_null_id as BIGINT)) * 1)) AS `sum(not_null_id + 1)`, (sum(cast(not_null_id as BIGINT)) - (count(cast(not_null_id as BIGINT)) * 1)) AS `sum(not_null_id - 1)`, sum(id), sum(not_null_id)] -----hashAgg[GLOBAL, groupByExpr=(), outputExpr=(count(cast(id as BIGINT)) AS `count(cast(id as BIGINT))`, count(cast(not_null_id as BIGINT)) AS `count(cast(not_null_id as BIGINT))`, sum(cast(id as BIGINT)) AS `sum(cast(id as BIGINT))`, sum(cast(not_null_id as BIGINT)) AS `sum(cast(not_null_id as BIGINT))`, sum(id) AS `sum(id)`, sum(not_null_id) AS `sum(not_null_id)`)] -------hashAgg[LOCAL, groupByExpr=(), outputExpr=(partial_count(cast(id as BIGINT)) AS `partial_count(cast(id as BIGINT))`, partial_count(cast(not_null_id as BIGINT)) AS `partial_count(cast(not_null_id as BIGINT))`, partial_sum(cast(id as BIGINT)) AS `partial_sum(cast(id as BIGINT))`, partial_sum(cast(not_null_id as BIGINT)) AS `partial_sum(cast(not_null_id as BIGINT))`, partial_sum(id) AS `partial_sum(id)`, partial_sum(not_null_id) AS `partial_sum(not_null_id)`)] ---------PhysicalProject[cast(id as BIGINT) AS `cast(id as BIGINT)`, cast(not_null_id as BIGINT) AS `cast(not_null_id as BIGINT)`, sr.id, sr.not_null_id] +--PhysicalProject[(sum(id) + (count(id) * 1)) AS `sum(id + 1)`, (sum(id) - (count(id) * 1)) AS `sum(id - 1)`, (sum(not_null_id) + (count(not_null_id) * 1)) AS `sum(not_null_id + 1)`, (sum(not_null_id) - (count(not_null_id) * 1)) AS `sum(not_null_id - 1)`, sum(id), sum(not_null_id)] +----hashAgg[GLOBAL, groupByExpr=(), outputExpr=(count(id) AS `count(id)`, count(not_null_id) AS `count(not_null_id)`, sum(id) AS `sum(id)`, sum(not_null_id) AS `sum(not_null_id)`)] +------hashAgg[LOCAL, groupByExpr=(), outputExpr=(partial_count(id) AS `partial_count(id)`, partial_count(not_null_id) AS `partial_count(not_null_id)`, partial_sum(id) AS `partial_sum(id)`, partial_sum(not_null_id) AS `partial_sum(not_null_id)`)] +--------PhysicalProject[sr.id, sr.not_null_id] ----------PhysicalOlapScan[sr] -- !sum_null_and_not_null_result -- --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
