snuyanzin commented on code in PR #3837:
URL: https://github.com/apache/calcite/pull/3837#discussion_r1663190224


##########
core/src/main/java/org/apache/calcite/rel/rules/MeasureRules.java:
##########
@@ -0,0 +1,532 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.rel.rules;
+
+import org.apache.calcite.plan.RelOptRule;
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.core.Aggregate;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Filter;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rel.core.Sort;
+import org.apache.calcite.rel.metadata.BuiltInMetadata;
+import org.apache.calcite.rel.metadata.RelMdMeasure;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexCorrelVariable;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexShuttle;
+import org.apache.calcite.rex.RexUtil;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.fun.SqlInternalOperators;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.tools.RelBuilder;
+import org.apache.calcite.util.MonotonicSupplier;
+import org.apache.calcite.util.Util;
+
+import com.google.common.base.Suppliers;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+
+import org.immutables.value.Value;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+import java.util.function.Supplier;
+
+import static com.google.common.collect.Iterables.getOnlyElement;
+
+/**
+ * Collection of planner rules that deal with measures.
+ *
+ * <p>A typical rule pushes down {@code M2V(measure)}
+ * until it reaches a {@code V2M(expression)}.
+ *
+ * @see org.apache.calcite.sql.fun.SqlInternalOperators#M2V
+ * @see org.apache.calcite.sql.fun.SqlInternalOperators#V2M
+ */
+public abstract class MeasureRules {
+
+  /** Utility class; not instantiable. */
+  private MeasureRules() {}
+
+  /** Returns the rules {@link #AGGREGATE2}, {@link #PROJECT} and
+   * {@link #PROJECT_SORT}.
+   *
+   * <p>{@link #AGGREGATE} and {@link #FILTER_SORT}, although defined in this
+   * class, are not included. */
+  public static Iterable<? extends RelOptRule> rules() {
+    return ImmutableList.of(AGGREGATE2, PROJECT, PROJECT_SORT);
+  }
+
+  /** Rule that rewrites each {@code M2V} aggregate call of an
+   * {@link Aggregate} into a {@code SINGLE_VALUE} call over an {@code M2X}
+   * expression that is computed in a {@link Project} beneath the aggregate.
+   *
+   * @see AggregateMeasureRule */
+  public static final RelOptRule AGGREGATE =
+      AggregateMeasureRuleConfig.DEFAULT.toRule();
+
+  /** Configuration for {@link AggregateMeasureRule}.
+   *
+   * <p>The default operand matches any {@link Aggregate} that has at least
+   * one call to {@link SqlInternalOperators#AGG_M2V}. */
+  @Value.Immutable
+  public interface AggregateMeasureRuleConfig extends RelRule.Config {
+    AggregateMeasureRuleConfig DEFAULT =
+        ImmutableAggregateMeasureRuleConfig.of()
+            .withOperandSupplier(b ->
+                b.operand(Aggregate.class)
+                    .predicate(aggregate ->
+                        aggregate.getAggCallList().stream()
+                            .map(AggregateCall::getAggregation)
+                            .anyMatch(f -> f == SqlInternalOperators.AGG_M2V))
+                    .anyInputs());
+
+    @Override default AggregateMeasureRule toRule() {
+      return new AggregateMeasureRule(this);
+    }
+  }
+
+  /** Rule that matches an {@link Aggregate} with at least one call to
+   * {@link SqlInternalOperators#AGG_M2V} and converts those calls
+   * to {@link SqlInternalOperators#M2X}.
+   *
+   * <p>Converts
+   *
+   * <pre>{@code
+   * Aggregate(a, b, AGG_M2V(c), SUM(d), AGG_M2V(e))
+   *   R
+   * }</pre>
+   *
+   * <p>to
+   *
+   * <pre>{@code
+   * Aggregate(a, b, SINGLE_VALUE(x), SUM(d), SINGLE_VALUE(y))
+   *   Project(a, b, c, d, e, M2X(c, SAME_PARTITION(a, b)) AS x,
+   *        M2X(e, SAME_PARTITION(a, b)) AS y)
+   *     R
+   * }</pre>
+   *
+   * <p>We rely on those {@code M2X} calls being pushed down until they merge
+   * with {@code V2M} and {@link ProjectMeasureRule} can apply.
+   *
+   * @see MeasureRules#AGGREGATE
+   * @see AggregateMeasureRuleConfig */
+  @SuppressWarnings("WeakerAccess")
+  public static class AggregateMeasureRule
+      extends RelRule<AggregateMeasureRuleConfig>
+      implements TransformationRule {
+    /** Creates an AggregateMeasureRule. */
+    protected AggregateMeasureRule(AggregateMeasureRuleConfig config) {
+      super(config);
+    }
+
+    @Override public void onMatch(RelOptRuleCall call) {
+      final Aggregate aggregate = call.rel(0);
+      final RelBuilder b = call.builder();
+      b.push(aggregate.getInput());
+      // Aggregate calls are created later, after projectPlus has widened the
+      // input with the extra M2X expressions.
+      final List<Function<RelBuilder, RelBuilder.AggCall>> aggCallList =
+          new ArrayList<>();
+      // One M2X expression per AGG_M2V call, appended after the input fields.
+      final List<RexNode> extraProjects = new ArrayList<>();
+      aggregate.getAggCallList().forEach(c -> {
+        if (c.getAggregation().kind == SqlKind.AGG_M2V) {
+          final int arg = getOnlyElement(c.getArgList());
+          // Position this call's M2X expression will have after projectPlus.
+          final int i = b.fields().size() + extraProjects.size();
+          extraProjects.add(
+              b.call(SqlInternalOperators.M2X, b.field(arg),
+                  b.call(SqlInternalOperators.SAME_PARTITION,
+                      b.fields(aggregate.getGroupSet()))));
+          // NOTE(review): c.filterArg is not applied to this SINGLE_VALUE
+          // call, so a FILTER on an AGG_M2V call is dropped (unlike in
+          // AggregateMeasure2Rule); confirm that this is intended.
+          aggCallList.add(b2 ->
+              b2.aggregateCall(SqlStdOperatorTable.SINGLE_VALUE, b2.field(i)));
+        } else {
+          aggCallList.add(b2 -> b2.aggregateCall(c));
+        }
+      });
+      b.projectPlus(extraProjects);
+      b.aggregate(
+          b.groupKey(aggregate.getGroupSet(), aggregate.groupSets),
+          bind(aggCallList).apply(b));
+      call.transformTo(b.build());
+    }
+
+    /** Converts a list of functions into a function that returns a list.
+     * It is named after the Monad bind operator. */
+    private static <T, E> Function<T, List<E>> bind(
+        List<Function<T, E>> list) {
+      return t -> {
+        final ImmutableList.Builder<E> builder = ImmutableList.builder();
+        list.forEach(f -> builder.add(f.apply(t)));
+        return builder.build();
+      };
+    }
+  }
+
+  // TODO rename field and class
+  /** Rule that merges an {@link Aggregate} whose calls are all
+   * {@code SINGLE_VALUE} onto the {@link Project} beneath it, when that
+   * project contains a {@code V2M} call (typically wrapped in {@code M2X}).
+   *
+   * @see ProjectMeasureRule */
+  public static final RelOptRule PROJECT =
+      ProjectMeasureRuleConfig.DEFAULT.toRule();
+
+  /** Configuration for {@link ProjectMeasureRule}.
+   *
+   * <p>The default operand matches an {@link Aggregate} whose calls are all
+   * {@code SINGLE_VALUE}, on top of a {@link Project} that contains a
+   * {@code V2M} call. */
+  @Value.Immutable
+  public interface ProjectMeasureRuleConfig extends RelRule.Config {
+    ProjectMeasureRuleConfig DEFAULT =
+        ImmutableProjectMeasureRuleConfig.of()
+            .withOperandSupplier(b ->
+                b.operand(Aggregate.class)
+                    .predicate(aggregate ->
+                        aggregate.getAggCallList().stream()
+                            .map(AggregateCall::getAggregation)
+                            .allMatch(f ->
+                                f == SqlStdOperatorTable.SINGLE_VALUE))
+                    .oneInput(b2 ->
+                        b2.operand(Project.class)
+                            .predicate(RexUtil.find(SqlKind.V2M)::inProject)
+                            .anyInputs()));
+
+    @Override default ProjectMeasureRule toRule() {
+      return new ProjectMeasureRule(this);
+    }
+  }
+
+  /** Rule that matches an {@link Aggregate} that contains at least one
+   * {@code AGG_M2V} call and expands each such call by asking the measure
+   * for its expression; the expansions are computed in a {@link Project}
+   * on top of a new {@link Aggregate}.
+   *
+   * @see AggregateMeasure2Rule */
+  public static final RelOptRule AGGREGATE2 =
+      AggregateMeasure2RuleConfig.DEFAULT.toRule();
+
+  /** Configuration for {@link AggregateMeasure2Rule}.
+   *
+   * <p>The default operand matches any {@link Aggregate} that has at least
+   * one call to {@link SqlInternalOperators#AGG_M2V}. */
+  @Value.Immutable
+  public interface AggregateMeasure2RuleConfig extends RelRule.Config {
+    AggregateMeasure2RuleConfig DEFAULT =
+        ImmutableAggregateMeasure2RuleConfig.of()
+            .withOperandSupplier(b ->
+                b.operand(Aggregate.class)
+                    .predicate(aggregate ->
+                        aggregate.getAggCallList().stream()
+                            .map(AggregateCall::getAggregation)
+                            .anyMatch(f -> f == SqlInternalOperators.AGG_M2V))
+                    .anyInputs());
+
+    @Override default AggregateMeasure2Rule toRule() {
+      return new AggregateMeasure2Rule(this);
+    }
+  }
+
+  /** Rule that matches an {@link Aggregate} with at least one call to
+   * {@link SqlInternalOperators#AGG_M2V} and expands these calls by
+   * asking the measure for its expression.
+   *
+   * <p>Converts
+   *
+   * <pre>{@code
+   * Aggregate(a, b, AGG_M2V(c), SUM(d), AGG_M2V(e))
+   *   R
+   * }</pre>
+   *
+   * <p>to
+   *
+   * <pre>{@code
+   * Project(a, b, RexSubQuery(...), sum_d, RexSubQuery(...))
+   *   Aggregate(a, b, AGG_M2M(c), SUM(d) AS sum_d, AGG_M2M(e))
+   *     R
+   * }</pre>
+   *
+   * <p>We will optimize those {@link org.apache.calcite.rex.RexSubQuery}
+   * later. For example,
+   *
+   * <pre>{@code
+   * SELECT e.deptno,
+   *     (SELECT AVG(sal)
+   *      FROM emp
+   *      WHERE deptno = e.deptno)
+   * FROM emp AS e
+   * GROUP BY e.deptno
+   * }</pre>
+   *
+   * <p>will become
+   *
+   * <pre>{@code
+   * SELECT deptno, AVG(sal)
+   * FROM emp
+   * GROUP BY deptno
+   * }</pre>
+   *
+   * @see org.apache.calcite.rel.metadata.RelMdMeasure
+   * @see MeasureRules#AGGREGATE2
+   * @see AggregateMeasure2RuleConfig */
+  @SuppressWarnings("WeakerAccess")
+  public static class AggregateMeasure2Rule
+      extends RelRule<AggregateMeasure2RuleConfig>
+      implements TransformationRule {
+    /** Creates an AggregateMeasure2Rule. */
+    protected AggregateMeasure2Rule(AggregateMeasure2RuleConfig config) {
+      super(config);
+    }
+
+    @Override public void onMatch(RelOptRuleCall call) {
+      final RelMetadataQuery mq = call.getMetadataQuery();
+      final Aggregate aggregate = call.rel(0);
+      final RelBuilder b = call.builder();
+      b.push(aggregate.getInput());
+      // Receives the correlation variable created by b.variable(holder); it
+      // is passed to the measure context and declared on the final Project.
+      final MonotonicSupplier<RexCorrelVariable> holder =
+          MonotonicSupplier.empty();
+      // Aggregate calls are applied while the input is on the builder's
+      // stack; output expressions are applied after the new Aggregate.
+      final List<Function<RelBuilder, RelBuilder.AggCall>> aggCallList =
+          new ArrayList<>();
+      final List<Function<RelBuilder, RexNode>> projects = new ArrayList<>();
+      b.variable(holder)
+          .let(b2 -> {
+            // Group keys pass through unchanged.
+            aggregate.getGroupSet().forEachInt(i ->
+                projects.add(b4 -> b4.field(i)));
+            // Memoize the RelBuilder so we don't create more than one.
+            @SuppressWarnings("FunctionalExpressionCanBeFolded")
+            final Supplier<RelBuilder> builderSupplier =
+                Suppliers.memoize(call::builder)::get;
+            final BuiltInMetadata.Measure.Context context =
+                RelMdMeasure.Contexts.forAggregate(aggregate, builderSupplier,
+                    holder.get());
+            aggregate.getAggCallList().forEach(c -> {
+              if (c.getAggregation().kind == SqlKind.AGG_M2V) {
+                final int arg = getOnlyElement(c.getArgList());
+                aggCallList.add(b3 ->
+                    b3.aggregateCall(SqlInternalOperators.AGG_M2M,
+                        b3.fields(c.getArgList()))
+                        .filter(c.filterArg < 0 ? null
+                            : b3.field(c.filterArg)));
+                // Adds this call's FILTER (if any) to the filters that the
+                // measure is evaluated under.
+                final BuiltInMetadata.Measure.Context context2 =
+                    new RelMdMeasure.DelegatingContext(context) {
+                      @Override public List<RexNode> getFilters(RelBuilder b) {
+                        final ImmutableList.Builder<RexNode> builder =
+                            ImmutableList.builder();
+                        builder.addAll(super.getFilters(b));
+                        if (c.filterArg >= 0) {
+                          builder.add(b.field(c.filterArg));
+                        }
+                        return builder.build();
+                      }
+                    };
+                projects.add(b4 -> mq.expand(b4.peek(), arg, context2));
+              } else {
+                // Output position of this call in the new Aggregate.
+                final int i =
+                    aggregate.getGroupSet().cardinality() + aggCallList.size();
+                aggCallList.add(b3 ->
+                    b3.aggregateCall(c)
+                        .filter(c.filterArg < 0 ? null
+                            : b3.field(c.filterArg)));
+                projects.add(b4 -> b4.field(i));
+              }
+            });
+            return b2;
+          });
+      b.aggregate(b.groupKey(aggregate.getGroupSet(), aggregate.groupSets),
+          bind(aggCallList).apply(b));
+      b.project(bind(projects).apply(b),
+          aggregate.getRowType().getFieldNames(),
+          false, ImmutableSet.of(holder.get().id));
+      call.transformTo(b.build());
+    }
+
+    /** Converts a list of functions into a function that returns a list.
+     * It is named after the Monad bind operator. */
+    private static <T, E> Function<T, List<E>> bind(
+        List<Function<T, E>> list) {
+      return t -> {
+        final ImmutableList.Builder<E> builder = ImmutableList.builder();
+        list.forEach(f -> builder.add(f.apply(t)));
+        return builder.build();
+      };
+    }
+  }
+
+  /** Rule that merges an {@link Aggregate} onto a {@link Project}.
+   *
+   * <p>Converts
+   *
+   * <pre>{@code
+   * Aggregate(a, b, SINGLE_VALUE(d) AS e)
+   *   Project(a, b, M2X(V2M(SUM(c) + 1), SAME_PARTITION(a, b)) AS d)
+   *     R
+   * }</pre>
+   *
+   * <p>to
+   *
+   * <pre>{@code
+   * Project(a, b, sum_c + 1 AS e)
+   *   Aggregate(a, b, SUM(c) AS sum_c)
+   *     R
+   * }</pre>
+   *
+   * <p>The rule does nothing unless every aggregate call is
+   * {@code SINGLE_VALUE} over a project expression of the form
+   * {@code M2X(V2M(e), ...)}.
+   *
+   * @see ProjectMeasureRuleConfig */
+  @SuppressWarnings("WeakerAccess")
+  public static class ProjectMeasureRule
+      extends RelRule<ProjectMeasureRuleConfig>
+      implements TransformationRule {
+    /** Creates a ProjectMeasureRule. */
+    protected ProjectMeasureRule(ProjectMeasureRuleConfig config) {
+      super(config);
+    }
+
+    @Override public void onMatch(RelOptRuleCall call) {
+      final Aggregate aggregate = call.rel(0);
+      final Project project = call.rel(1);
+      // The operand only checks that the Project contains a V2M somewhere.
+      // Bail out, rather than throw from the planner, if some call does not
+      // have the shape that toRex can convert.
+      for (AggregateCall aggregateCall : aggregate.getAggCallList()) {
+        if (!isConvertible(aggregateCall, project)) {
+          return;
+        }
+      }
+      final RelBuilder b = call.builder();
+      b.push(project)
+          .aggregateRex(
+              b.groupKey(aggregate.getGroupSet(), aggregate.getGroupSets()),
+              true,
+              Util.transform(aggregate.getAggCallList(),
+                  aggregateCall -> toRex(aggregateCall, project)));
+      call.transformTo(b.build());
+    }
+
+    /** Returns whether {@code aggregateCall} is a {@code SINGLE_VALUE} call
+     * whose argument is, in {@code project}, an expression of the form
+     * {@code M2X(V2M(e), ...)}. */
+    private static boolean isConvertible(AggregateCall aggregateCall,
+        Project project) {
+      if (aggregateCall.getAggregation().kind != SqlKind.SINGLE_VALUE
+          || aggregateCall.getArgList().size() != 1) {
+        return false;
+      }
+      final RexNode e =
+          project.getProjects().get(aggregateCall.getArgList().get(0));
+      return e.getKind() == SqlKind.M2X
+          && ((RexCall) e).operands.get(0).getKind() == SqlKind.V2M;
+    }
+
+    /** Returns the expression {@code e} wrapped by the project expression
+     * {@code M2X(V2M(e), ...)} that a {@code SINGLE_VALUE} call refers to.
+     *
+     * @throws UnsupportedOperationException if the call does not have that
+     *   form; callers check {@link #isConvertible} first */
+    private static RexNode toRex(AggregateCall aggregateCall,
+        Project project) {
+      if (aggregateCall.getAggregation().kind != SqlKind.SINGLE_VALUE) {
+        throw new UnsupportedOperationException(
+            "expected SINGLE_VALUE, got " + aggregateCall);
+      }
+      final int arg = getOnlyElement(aggregateCall.getArgList());
+      final RexNode e = project.getProjects().get(arg);
+      if (e.getKind() != SqlKind.M2X) {
+        throw new UnsupportedOperationException("expected M2X, got " + e);
+      }
+      final RexNode measure = ((RexCall) e).operands.get(0);
+      if (measure.getKind() != SqlKind.V2M) {
+        throw new UnsupportedOperationException(
+            "expected V2M, got " + measure);
+      }
+      return ((RexCall) measure).operands.get(0);
+    }
+  }
+
+  /** Rule that matches a {@link Filter} on top of a {@link Sort}; it is
+   * intended to push {@code M2V} calls in the filter condition down.
+   *
+   * <p>The default operand does not check for {@code M2V}, and
+   * {@link FilterSortMeasureRule} does not yet rewrite anything, so this
+   * rule currently never transforms a plan.
+   *
+   * @see FilterSortMeasureRule */
+  public static final RelOptRule FILTER_SORT =
+      FilterSortMeasureRuleConfig.DEFAULT
+          .as(FilterSortMeasureRuleConfig.class)
+          .toRule();
+
+  /** Configuration for {@link FilterSortMeasureRule}.
+   *
+   * <p>The default operand matches any {@link Filter} whose single input is
+   * a {@link Sort}. */
+  @Value.Immutable
+  public interface FilterSortMeasureRuleConfig extends RelRule.Config {
+    FilterSortMeasureRuleConfig DEFAULT =
+        ImmutableFilterSortMeasureRuleConfig.of()
+            .withOperandSupplier(b0 ->
+                b0.operand(Filter.class)
+                    .oneInput(b1 ->
+                        b1.operand(Sort.class).anyInputs()));
+
+    @Override default FilterSortMeasureRule toRule() {
+      return new FilterSortMeasureRule(this);
+    }
+  }
+
+  /** Rule that matches a {@link Filter} on top of a {@link Sort} and is
+   * meant to rewrite {@code M2V} calls in the filter's condition.
+   *
+   * <p>The rewrite is not implemented yet: the condition is returned
+   * unchanged, so {@link #onMatch} never registers a new expression.
+   *
+   * @see MeasureRules#FILTER_SORT
+   * @see FilterSortMeasureRuleConfig */
+  @SuppressWarnings("WeakerAccess")
+  public static class FilterSortMeasureRule
+      extends RelRule<FilterSortMeasureRuleConfig>
+      implements TransformationRule {
+    /** Creates a FilterSortMeasureRule. */
+    protected FilterSortMeasureRule(FilterSortMeasureRuleConfig config) {
+      super(config);
+    }
+
+    @Override public void onMatch(RelOptRuleCall call) {
+      final Filter filter = call.rel(0);
+      final RexNode condition = rewriteCondition(filter);
+      if (condition.equals(filter.getCondition())) {
+        // Nothing was rewritten; do not register an equivalent Filter.
+        return;
+      }
+      final RelBuilder relBuilder = call.builder();
+      relBuilder.push(filter.getInput())
+          .filter(condition);
+      call.transformTo(relBuilder.build());
+    }
+
+    /** Returns the condition of {@code filter} with its {@code M2V} calls
+     * rewritten.
+     *
+     * <p>Placeholder: currently returns the condition unchanged.
+     * TODO: implement pushing {@code M2V} down through the {@link Sort}. */
+    private static RexNode rewriteCondition(Filter filter) {
+      return filter.getCondition();
+    }
+  }
+
+  /** Rule that matches a {@link Project} containing an {@code M2V} call
+   * whose input is a {@link Sort}, and pushes that {@code M2V} call down.
+   *
+   * @see ProjectSortMeasureRuleConfig */
+  public static final RelOptRule PROJECT_SORT =
+      ProjectSortMeasureRuleConfig.DEFAULT
+          .as(ProjectSortMeasureRuleConfig.class).toRule();
+
+  /** Rule that ...

Review Comment:
   
   This looks like an unfinished description ("Rule that ..."); please complete it.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to