mihaibudiu commented on code in PR #4961:
URL: https://github.com/apache/calcite/pull/4961#discussion_r3292179690
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsOnGroupKeysRule.java:
##########
@@ -147,12 +153,118 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) {
default:
return null;
}
- final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
- RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
- if (!ref.getType().equals(call.getType())) {
- ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref);
+ final List<Integer> argList = call.getArgList();
+ if (argList.size() != 1) {
+ return null;
+ }
+ final int arg = argList.get(0);
+
+ // Case 1: argument directly references a group-by key
+ if (aggregate.getGroupSet().get(arg)) {
+ final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
+ RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
+ if (!ref.getType().equals(call.getType())) {
+ ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref);
+ }
+ return ref;
+ }
+
+ // Case 2: argument is an expression in a Project below the Aggregate
+ RelNode input = aggregate.getInput();
+ if (input instanceof HepRelVertex) {
+ input = ((HepRelVertex) input).getCurrentRel();
+ }
+ if (!(input instanceof Project)) {
+ return null;
+ }
+ final Project project = (Project) input;
+ if (arg < 0 || arg >= project.getProjects().size()) {
+ return null;
+ }
+ final RexNode expr = project.getProjects().get(arg);
+ if (!RexUtil.isDeterministic(expr)) {
+ return null;
+ }
+ // Check that all columns referenced in the expression are group-by keys.
+ // This ensures that the expression value is constant within each group.
+ final @Nullable RexNode translated =
+ translateToGroupRefs(expr, project, aggregate);
+ if (translated == null) {
+ return null;
+ }
+ if (!translated.getType().equals(call.getType())) {
+ return rexBuilder.makeCast(call.getParserPosition(), call.getType(), translated);
+ }
+ return translated;
+ }
+
+ /**
+ * Translates an expression so that its {@link RexInputRef}s reference
+ * the group keys of the aggregate rather than the input to the project.
+ *
+ * @return the translated expression, or null if the expression references
+ * columns that are not group-by keys
+ */
+ private static @Nullable RexNode translateToGroupRefs(
+ RexNode expr, Project project, Aggregate aggregate) {
+ final List<RexNode> projects = project.getProjects();
+ final GroupRefTranslator translator = new GroupRefTranslator(projects, aggregate);
+ final RexNode result = expr.accept(translator);
+ return translator.failed ? null : result;
+ }
+
+ /**
+ * Shuttle that translates input refs to aggregate group key refs.
+ *
+ * <p>For each column reference in the expression being examined:
+ * 1. If the expression is a direct pass-through of a project column,
+ * check if that project column is in the GROUP BY set
+ * 2. If the expression contains references to input columns,
+ * verify that those input columns are in the GROUP BY set
+ * 3. Map to the corresponding group key index in the aggregate
+ *
+ * <p>This ensures the expression references only columns that are constant
+ * within each group.
+ */
+ private static class GroupRefTranslator extends RexShuttle {
+ private final List<RexNode> projects;
+ private final Aggregate aggregate;
+ private boolean failed = false;
+
+ GroupRefTranslator(List<RexNode> projects, Aggregate aggregate) {
+ this.projects = projects;
+ this.aggregate = aggregate;
+ }
+
+ @Override public RexNode visitInputRef(RexInputRef inputRef) {
+ if (failed) {
+ return inputRef;
+ }
+ final int inputIndex = inputRef.getIndex();
+ // Look for a project column that is a direct pass-through of this input.
+ // For example, if a project has SAL=[$5], and the expression references $5,
+ // we need to map it to the corresponding group key.
+ int projectOutputIndex = -1;
+ for (int i = 0; i < projects.size(); i++) {
+ final RexNode projExpr = projects.get(i);
+ if (projExpr instanceof RexInputRef
+ && ((RexInputRef) projExpr).getIndex() == inputIndex) {
+ projectOutputIndex = i;
+ break;
+ }
+ }
+ // The input column must be available through a project column that is in
+ // the GROUP BY set. If not found, the input is embedded in a computed
+ // expression, which means the optimization cannot proceed safely.
+ if (projectOutputIndex < 0
Review Comment:
Can you add a few negative tests that exercise these paths?
For example, expressions that mix group-by columns with other columns.
##########
core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsOnGroupKeysRule.java:
##########
@@ -147,12 +153,118 @@ protected AggregateReduceFunctionsOnGroupKeysRule(Config config) {
default:
return null;
}
- final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
- RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
- if (!ref.getType().equals(call.getType())) {
- ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref);
+ final List<Integer> argList = call.getArgList();
+ if (argList.size() != 1) {
+ return null;
+ }
+ final int arg = argList.get(0);
+
+ // Case 1: argument directly references a group-by key
+ if (aggregate.getGroupSet().get(arg)) {
+ final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
+ RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());
+ if (!ref.getType().equals(call.getType())) {
Review Comment:
Do the tests cover this case, where a cast is inserted?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]