This is an automated email from the ASF dual-hosted git repository.
snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new a1b9e42df9e [FLINK-35751][table] Migrate SplitPythonConditionFromCorrelateRule to java
a1b9e42df9e is described below
commit a1b9e42df9e7b6d3a398d1038063e6b7396891fe
Author: Jacky Lau <[email protected]>
AuthorDate: Mon Nov 25 01:56:56 2024 +0800
[FLINK-35751][table] Migrate SplitPythonConditionFromCorrelateRule to java
---
.../logical/CalcPythonCorrelateTransposeRule.java | 4 +-
.../SplitPythonConditionFromCorrelateRule.java | 223 +++++++++++++++
.../flink/table/planner/plan/utils/PythonUtil.java | 316 +++++++++++++++++++++
.../SplitPythonConditionFromCorrelateRule.scala | 148 ----------
.../table/planner/plan/utils/PythonUtil.scala | 243 ----------------
5 files changed, 541 insertions(+), 393 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
index 600ccc951f1..0f12cc82328 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
@@ -78,8 +78,8 @@ public class CalcPythonCorrelateTransposeRule extends RelOptRule {
StreamPhysicalCorrelateRule.getTableScan(mergedCalc);
RexProgram mergedCalcProgram = mergedCalc.getProgram();
- InputRefRewriter inputRefRewriter =
- new InputRefRewriter(
+ SplitPythonConditionFromCorrelateRule.InputRefRewriter
inputRefRewriter =
+ new SplitPythonConditionFromCorrelateRule.InputRefRewriter(
correlate.getRowType().getFieldCount()
- mergedCalc.getRowType().getFieldCount());
List<RexNode> correlateFilters =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java
new file mode 100644
index 00000000000..bcfe4a43660
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
+import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexProgramBuilder;
+import org.apache.calcite.rex.RexUtil;
+import org.immutables.value.Value;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static
org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall;
+import static
org.apache.flink.table.planner.plan.utils.PythonUtil.isNonPythonCall;
+
+/**
+ * Rule that splits a {@link FlinkLogicalCalc} which is the right input of a {@link
+ * FlinkLogicalCorrelate} and contains Python functions in its condition into two {@link
+ * FlinkLogicalCalc}s. The {@link FlinkLogicalCalc} without Python function conditions stays the
+ * right input of the {@link FlinkLogicalCorrelate}, while the other {@link FlinkLogicalCalc},
+ * which holds the Python function conditions, is placed downstream of the {@link
+ * FlinkLogicalCorrelate}. Currently, only inner join is supported.
+ *
+ * <p>After this rule is applied, there are no Python functions in the condition of the upstream
+ * {@link FlinkLogicalCalc}.
+ */
+@Value.Enclosing
+public class SplitPythonConditionFromCorrelateRule
+        extends RelRule<
+                SplitPythonConditionFromCorrelateRule.SplitPythonConditionFromCorrelateRuleConfig> {
+
+    public static final SplitPythonConditionFromCorrelateRule INSTANCE =
+            SplitPythonConditionFromCorrelateRuleConfig.DEFAULT.toRule();
+
+    private SplitPythonConditionFromCorrelateRule(
+            SplitPythonConditionFromCorrelateRuleConfig config) {
+        super(config);
+    }
+
+    /**
+     * Matches an inner correlate whose right-hand table function call is not itself a Python call
+     * and whose right-hand (merged) calc has a condition containing a Python function call.
+     */
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+        FlinkLogicalCorrelate correlate = call.rel(0);
+        if (correlate.getJoinType() != JoinRelType.INNER) {
+            // Only inner correlates are supported; avoid merging calcs for any other join type.
+            return false;
+        }
+
+        FlinkLogicalCalc right = call.rel(2);
+        FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(right);
+        FlinkLogicalTableFunctionScan tableScan =
+                StreamPhysicalCorrelateRule.getTableScan(mergedCalc);
+        RexProgram program = mergedCalc.getProgram();
+
+        return isNonPythonCall(tableScan.getCall())
+                && program != null
+                && program.getCondition() != null
+                && containsPythonCall(program.expandLocalRef(program.getCondition()));
+    }
+
+    /**
+     * Keeps the conjuncts of the merged calc's condition that contain no Python call in a calc
+     * below the correlate, and moves the Python conjuncts into a new calc on top of it.
+     */
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        FlinkLogicalCorrelate correlate = call.rel(0);
+        FlinkLogicalCalc right = call.rel(2);
+        RexBuilder rexBuilder = call.builder().getRexBuilder();
+        FlinkLogicalCalc mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(right);
+        RexProgram mergedCalcProgram = mergedCalc.getProgram();
+        RelNode input = mergedCalc.getInput();
+
+        List<RexNode> correlateFilters =
+                RelOptUtil.conjunctions(
+                        mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition()));
+
+        // Conjuncts without any Python call stay below the correlate.
+        List<RexNode> remainingFilters =
+                correlateFilters.stream()
+                        .filter(filter -> !containsPythonCall(filter))
+                        .collect(Collectors.toList());
+        RexNode bottomCalcCondition = RexUtil.composeConjunction(rexBuilder, remainingFilters);
+
+        FlinkLogicalCalc newBottomCalc =
+                new FlinkLogicalCalc(
+                        mergedCalc.getCluster(),
+                        mergedCalc.getTraitSet(),
+                        input,
+                        RexProgram.create(
+                                input.getRowType(),
+                                mergedCalcProgram.getProjectList(),
+                                bottomCalcCondition,
+                                mergedCalc.getRowType(),
+                                rexBuilder));
+
+        FlinkLogicalCorrelate newCorrelate =
+                new FlinkLogicalCorrelate(
+                        correlate.getCluster(),
+                        correlate.getTraitSet(),
+                        correlate.getLeft(),
+                        newBottomCalc,
+                        correlate.getCorrelationId(),
+                        correlate.getRequiredColumns(),
+                        correlate.getJoinType());
+
+        // Above the correlate, the right-side fields start after the correlate's first
+        // (correlate field count - merged calc field count) fields, so input refs are shifted.
+        // NOTE(review): the Python conjuncts are expanded over the merged calc's *input*, but are
+        // evaluated over the merged calc's *output* after the shift; this is only equivalent when
+        // the merged calc's projection is an identity - confirm.
+        InputRefRewriter inputRefRewriter =
+                new InputRefRewriter(
+                        correlate.getRowType().getFieldCount()
+                                - mergedCalc.getRowType().getFieldCount());
+
+        List<RexNode> pythonFilters =
+                correlateFilters.stream()
+                        .filter(filter -> containsPythonCall(filter))
+                        .map(filter -> filter.accept(inputRefRewriter))
+                        .collect(Collectors.toList());
+        RexNode topCalcCondition = RexUtil.composeConjunction(rexBuilder, pythonFilters);
+
+        // The top calc projects the expression list of a fresh program builder over the new
+        // correlate's row type and filters with the (re-indexed) Python conjuncts.
+        RexProgram rexProgram =
+                new RexProgramBuilder(newCorrelate.getRowType(), rexBuilder).getProgram();
+        FlinkLogicalCalc newTopCalc =
+                new FlinkLogicalCalc(
+                        newCorrelate.getCluster(),
+                        newCorrelate.getTraitSet(),
+                        newCorrelate,
+                        RexProgram.create(
+                                newCorrelate.getRowType(),
+                                rexProgram.getExprList(),
+                                topCalcCondition,
+                                newCorrelate.getRowType(),
+                                rexBuilder));
+
+        call.transformTo(newTopCalc);
+    }
+
+    /** Rule configuration. */
+    @Value.Immutable(singleton = false)
+    public interface SplitPythonConditionFromCorrelateRuleConfig extends RelRule.Config {
+        SplitPythonConditionFromCorrelateRule.SplitPythonConditionFromCorrelateRuleConfig DEFAULT =
+                ImmutableSplitPythonConditionFromCorrelateRule
+                        .SplitPythonConditionFromCorrelateRuleConfig.builder()
+                        .build()
+                        .withOperandSupplier(
+                                b0 ->
+                                        b0.operand(FlinkLogicalCorrelate.class)
+                                                .inputs(
+                                                        b1 ->
+                                                                b1.operand(FlinkLogicalRel.class)
+                                                                        .anyInputs(),
+                                                        b2 ->
+                                                                b2.operand(FlinkLogicalCalc.class)
+                                                                        .anyInputs()))
+                        .withDescription("SplitPythonConditionFromCorrelateRule");
+
+        @Override
+        default SplitPythonConditionFromCorrelateRule toRule() {
+            return new SplitPythonConditionFromCorrelateRule(this);
+        }
+    }
+
+    /**
+     * Shifts the index of every {@link RexInputRef} by a fixed offset. The Python conditions are
+     * expressed against the calc upstream of the correlate, so once they are pushed into the calc
+     * downstream of the correlate their input references must be re-indexed.
+     *
+     * <p>NOTE(review): only input refs reached through {@link RexCall} operands are rewritten;
+     * every other node kind is returned unchanged by {@link #visitNode} unless the base visitor
+     * routes it elsewhere (e.g. an input ref nested inside a field access) - confirm such nodes
+     * cannot appear in Python conditions.
+     */
+    static class InputRefRewriter extends RexDefaultVisitor<RexNode> {
+        private final int offset;
+
+        /** @param offset the start offset of the input refs in the downstream calc */
+        public InputRefRewriter(int offset) {
+            this.offset = offset;
+        }
+
+        @Override
+        public RexNode visitInputRef(RexInputRef inputRef) {
+            return new RexInputRef(inputRef.getIndex() + offset, inputRef.getType());
+        }
+
+        @Override
+        public RexNode visitCall(RexCall call) {
+            return call.clone(
+                    call.getType(),
+                    call.getOperands().stream()
+                            .map(operand -> operand.accept(this))
+                            .collect(Collectors.toList()));
+        }
+
+        @Override
+        public RexNode visitNode(RexNode rexNode) {
+            return rexNode;
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java
new file mode 100644
index 00000000000..65efb05f430
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.utils;
+
+import org.apache.flink.table.functions.DeclarativeAggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.functions.python.PythonFunction;
+import org.apache.flink.table.functions.python.PythonFunctionKind;
+import
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
+import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
+import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
+import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import
org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
+
+import org.apache.calcite.plan.hep.HepRelVertex;
+import org.apache.calcite.plan.volcano.RelSubset;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlKind;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Utilities for detecting Python functions in planner expressions and aggregate calls, and for
+ * recognizing calcs that flatten a single struct-typed input.
+ */
+public final class PythonUtil {
+
+    private PythonUtil() {
+        // Static utility class; not meant to be instantiated.
+    }
+
+    /**
+     * Checks whether the specified node contains a call to a Python function of the specified
+     * kind, searching the node and its operands recursively. If {@code pythonFunctionKind} is
+     * null, Python functions of any kind match.
+     *
+     * @param node the RexNode to check
+     * @param pythonFunctionKind the kind of the Python function, or null to match any kind
+     * @return true if the node contains a matching Python function call
+     */
+    public static boolean containsPythonCall(
+            RexNode node, PythonFunctionKind pythonFunctionKind) {
+        return node.accept(
+                new FunctionFinder(true, Optional.ofNullable(pythonFunctionKind), true));
+    }
+
+    /**
+     * Checks whether the specified node contains a call to a Python function of any kind; same as
+     * {@code containsPythonCall(node, null)}.
+     *
+     * @param node the RexNode to check
+     * @return true if the node contains a Python function call
+     */
+    public static boolean containsPythonCall(RexNode node) {
+        return containsPythonCall(node, null);
+    }
+
+    /**
+     * Checks whether the specified node contains a call whose operator is not a Python function,
+     * searching the node and its operands recursively. Calls to operators that are not
+     * user-defined functions (e.g. comparison operators) count as non-Python calls.
+     *
+     * @param node the RexNode to check
+     * @return true if the node contains a non-Python function call
+     */
+    public static boolean containsNonPythonCall(RexNode node) {
+        return node.accept(new FunctionFinder(false, Optional.empty(), true));
+    }
+
+    /**
+     * Checks whether the specified node itself is a call to a Python function of the specified
+     * kind; the operands of a call are not inspected, while a field access is judged by the
+     * expression it accesses. If {@code pythonFunctionKind} is null, Python functions of any kind
+     * match.
+     *
+     * @param node the RexNode to check
+     * @param pythonFunctionKind the kind of the Python function, or null to match any kind
+     * @return true if the specified node is a matching Python function call
+     */
+    public static boolean isPythonCall(RexNode node, PythonFunctionKind pythonFunctionKind) {
+        return node.accept(
+                new FunctionFinder(true, Optional.ofNullable(pythonFunctionKind), false));
+    }
+
+    /**
+     * Checks whether the specified node itself is a call to a Python function of any kind; same
+     * as {@code isPythonCall(node, null)}.
+     *
+     * @param node the RexNode to check
+     * @return true if the specified node is a Python function call
+     */
+    public static boolean isPythonCall(RexNode node) {
+        return isPythonCall(node, null);
+    }
+
+    /**
+     * Checks whether the specified node itself is a call whose operator is not a Python function;
+     * the operands of a call are not inspected.
+     *
+     * @param node the RexNode to check
+     * @return true if the specified node is a non-Python function call
+     */
+    public static boolean isNonPythonCall(RexNode node) {
+        return node.accept(new FunctionFinder(false, Optional.empty(), false));
+    }
+
+    /**
+     * Checks whether the specified aggregate call is a Python aggregate function of any kind; same
+     * as {@code isPythonAggregate(call, null)}.
+     *
+     * @param call the AggregateCall to check
+     * @return true if the specified call is a Python function aggregate
+     */
+    public static boolean isPythonAggregate(AggregateCall call) {
+        return isPythonAggregate(call, null);
+    }
+
+    /**
+     * Checks whether the specified aggregate call is a Python aggregate function of the specified
+     * kind. If {@code pythonFunctionKind} is null, Python functions of any kind match.
+     *
+     * @param call the AggregateCall to check
+     * @param pythonFunctionKind the kind of the Python function, or null to match any kind
+     * @return true if the specified call is a matching Python function aggregate
+     */
+    public static boolean isPythonAggregate(
+            AggregateCall call, PythonFunctionKind pythonFunctionKind) {
+        SqlAggFunction aggregation = call.getAggregation();
+        if (aggregation instanceof AggSqlFunction) {
+            return isPythonFunction(
+                    ((AggSqlFunction) aggregation).aggregateFunction(), pythonFunctionKind);
+        }
+        if (aggregation instanceof BridgingSqlAggFunction) {
+            return isPythonFunction(
+                    ((BridgingSqlAggFunction) aggregation).getDefinition(), pythonFunctionKind);
+        }
+        return false;
+    }
+
+    /**
+     * Checks whether the specified aggregate call is a built-in aggregate: an {@link
+     * AggSqlFunction} wrapping a {@link BuiltInAggregateFunction}, a {@link
+     * BridgingSqlAggFunction} wrapping a {@link DeclarativeAggregateFunction}, or any other kind
+     * of {@link SqlAggFunction}.
+     *
+     * @param call the AggregateCall to check
+     * @return true if the specified call is treated as a built-in aggregate
+     */
+    public static boolean isBuiltInAggregate(AggregateCall call) {
+        SqlAggFunction aggregation = call.getAggregation();
+        if (aggregation instanceof AggSqlFunction) {
+            AggSqlFunction aggSqlFunction = (AggSqlFunction) aggregation;
+            return aggSqlFunction.aggregateFunction() instanceof BuiltInAggregateFunction;
+        }
+        if (aggregation instanceof BridgingSqlAggFunction) {
+            BridgingSqlAggFunction bridgingSqlAggFunction = (BridgingSqlAggFunction) aggregation;
+            return bridgingSqlAggFunction.getDefinition() instanceof DeclarativeAggregateFunction;
+        }
+        return true;
+    }
+
+    /**
+     * Checks whether the Python function invoked by the specified call takes a whole row as input.
+     *
+     * @param call the call to check
+     * @return the function's {@code takesRowAsInput()} flag, or false if the call's operator is
+     *     not a scalar, table or bridging function
+     * @throws ClassCastException if the invoked function is not a {@link PythonFunction}
+     */
+    public static boolean takesRowAsInput(RexCall call) {
+        return getFunctionDefinition(call)
+                .map(definition -> ((PythonFunction) definition).takesRowAsInput())
+                .orElse(false);
+    }
+
+    /**
+     * Returns the function definition invoked through a {@link ScalarSqlFunction}, {@link
+     * TableSqlFunction} or {@link BridgingSqlFunction} operator, or empty for any other operator.
+     */
+    private static Optional<FunctionDefinition> getFunctionDefinition(RexCall call) {
+        if (call.getOperator() instanceof ScalarSqlFunction) {
+            ScalarSqlFunction function = (ScalarSqlFunction) call.getOperator();
+            return Optional.ofNullable(function.scalarFunction());
+        }
+        if (call.getOperator() instanceof TableSqlFunction) {
+            TableSqlFunction function = (TableSqlFunction) call.getOperator();
+            return Optional.ofNullable(function.udtf());
+        }
+        if (call.getOperator() instanceof BridgingSqlFunction) {
+            BridgingSqlFunction function = (BridgingSqlFunction) call.getOperator();
+            return Optional.ofNullable(function.getDefinition());
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Checks whether the function is a {@link PythonFunction} of the given kind; a null kind
+     * matches any Python function.
+     */
+    private static boolean isPythonFunction(
+            FunctionDefinition function, PythonFunctionKind pythonFunctionKind) {
+        if (!(function instanceof PythonFunction)) {
+            return false;
+        }
+        PythonFunction pythonFunction = (PythonFunction) function;
+        return pythonFunctionKind == null
+                || pythonFunction.getPythonFunctionKind() == pythonFunctionKind;
+    }
+
+    /**
+     * Checks whether the calc only flattens its single struct-typed input: its input (unwrapped
+     * from a {@link RelSubset} or {@link HepRelVertex}) is a {@link FlinkLogicalCalc}, it has no
+     * condition, and its i-th projection is exactly an access of field i of input 0 (optionally
+     * wrapped in {@code AS}), covering every field of the struct.
+     *
+     * @param calc the calc to check
+     * @return true if the calc is a flatten calc; false also when its input is not wrapped in a
+     *     {@link RelSubset} or {@link HepRelVertex}
+     */
+    public static boolean isFlattenCalc(FlinkLogicalCalc calc) {
+        RelNode child = calc.getInput();
+        if (child instanceof RelSubset) {
+            child = ((RelSubset) child).getOriginal();
+        } else if (child instanceof HepRelVertex) {
+            child = ((HepRelVertex) child).getCurrentRel();
+        } else {
+            return false;
+        }
+        if (!(child instanceof FlinkLogicalCalc)) {
+            return false;
+        }
+
+        if (calc.getProgram().getCondition() != null) {
+            return false;
+        }
+
+        List<RelDataTypeField> inputFields = calc.getProgram().getInputRowType().getFieldList();
+        if (inputFields.size() != 1 || !inputFields.get(0).getType().isStruct()) {
+            return false;
+        }
+
+        List<RexNode> projects =
+                calc.getProgram().getProjectList().stream()
+                        .map(calc.getProgram()::expandLocalRef)
+                        .collect(Collectors.toList());
+
+        if (inputFields.get(0).getType().getFieldCount() != projects.size()) {
+            return false;
+        }
+
+        return IntStream.range(0, projects.size())
+                .allMatch(idx -> projects.get(idx).accept(new FieldReferenceDetector(idx)));
+    }
+
+    /** Visitor that looks for calls to Python (or non-Python) functions in a {@link RexNode}. */
+    private static class FunctionFinder extends RexDefaultVisitor<Boolean> {
+        private final boolean findPythonFunction;
+        private final Optional<PythonFunctionKind> pythonFunctionKind;
+        private final boolean recursive;
+
+        /**
+         * Creates a finder for the specified kind of function.
+         *
+         * @param findPythonFunction true to find Python functions, false to find non-Python ones
+         * @param pythonFunctionKind the kind of the Python function; empty matches any kind
+         * @param recursive whether the operands of a call are checked as well
+         */
+        public FunctionFinder(
+                boolean findPythonFunction,
+                Optional<PythonFunctionKind> pythonFunctionKind,
+                boolean recursive) {
+            this.findPythonFunction = findPythonFunction;
+            this.pythonFunctionKind = pythonFunctionKind;
+            this.recursive = recursive;
+        }
+
+        /** Checks whether the call invokes a Python function of the requested kind. */
+        private boolean isPythonRexCall(RexCall rexCall) {
+            PythonFunctionKind kind = pythonFunctionKind.orElse(null);
+            return PythonUtil.getFunctionDefinition(rexCall)
+                    .map(definition -> PythonUtil.isPythonFunction(definition, kind))
+                    .orElse(false);
+        }
+
+        @Override
+        public Boolean visitCall(RexCall call) {
+            return findPythonFunction == isPythonRexCall(call)
+                    || (recursive
+                            && call.getOperands().stream()
+                                    .anyMatch(operand -> operand.accept(this)));
+        }
+
+        @Override
+        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
+            return fieldAccess.getReferenceExpr().accept(this);
+        }
+
+        @Override
+        public Boolean visitNode(RexNode rexNode) {
+            return false;
+        }
+    }
+
+    /** Checks whether a rexNode is only a field reference of the given index. */
+    private static class FieldReferenceDetector extends RexDefaultVisitor<Boolean> {
+        private final int idx;
+
+        public FieldReferenceDetector(int idx) {
+            this.idx = idx;
+        }
+
+        @Override
+        public Boolean visitNode(RexNode rexNode) {
+            return false;
+        }
+
+        @Override
+        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
+            if (fieldAccess.getField().getIndex() != idx) {
+                return false;
+            }
+            RexNode expr = fieldAccess.getReferenceExpr();
+            // The accessed expression must be input ref #0, i.e. the single struct input.
+            return expr instanceof RexInputRef && ((RexInputRef) expr).getIndex() == 0;
+        }
+
+        @Override
+        public Boolean visitCall(RexCall call) {
+            // An alias does not change what is referenced; look through it.
+            if (call.getKind() == SqlKind.AS) {
+                return call.getOperands().get(0).accept(this);
+            }
+            return false;
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
deleted file mode 100644
index e6ccaa13baa..00000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.rules.logical
-
-import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc,
FlinkLogicalCorrelate, FlinkLogicalRel, FlinkLogicalTableFunctionScan}
-import
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule
-import
org.apache.flink.table.planner.plan.utils.PythonUtil.{containsPythonCall,
isNonPythonCall}
-import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor
-
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
-import org.apache.calcite.plan.RelOptRule.{any, operand}
-import org.apache.calcite.rel.core.JoinRelType
-import org.apache.calcite.rex._
-
-import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
-
-/**
- * Rule will split a [[FlinkLogicalCalc]] which is the upstream of a
[[FlinkLogicalCorrelate]] and
- * contains Python Functions in condition into two [[FlinkLogicalCalc]]s. One
of the
- * [[FlinkLogicalCalc]] without python function condition is the upstream of
the
- * [[FlinkLogicalCorrelate]], but the other [[[FlinkLogicalCalc]] with python
function conditions is
- * the downstream of the [[FlinkLogicalCorrelate]]. Currently, only inner join
is supported.
- *
- * After this rule is applied, there will be no Python Functions in the
condition of the upstream
- * [[FlinkLogicalCalc]].
- */
-class SplitPythonConditionFromCorrelateRule
- extends RelOptRule(
- operand(
- classOf[FlinkLogicalCorrelate],
- operand(classOf[FlinkLogicalRel], any),
- operand(classOf[FlinkLogicalCalc], any)),
- "SplitPythonConditionFromCorrelateRule") {
- override def matches(call: RelOptRuleCall): Boolean = {
- val correlate: FlinkLogicalCorrelate =
call.rel(0).asInstanceOf[FlinkLogicalCorrelate]
- val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
- val joinType: JoinRelType = correlate.getJoinType
- val mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(right)
- val tableScan = StreamPhysicalCorrelateRule
- .getTableScan(mergedCalc)
- .asInstanceOf[FlinkLogicalTableFunctionScan]
- joinType == JoinRelType.INNER &&
- isNonPythonCall(tableScan.getCall) &&
- Option(mergedCalc.getProgram.getCondition)
- .map(mergedCalc.getProgram.expandLocalRef)
- .exists(containsPythonCall(_))
- }
-
- override def onMatch(call: RelOptRuleCall): Unit = {
- val correlate: FlinkLogicalCorrelate =
call.rel(0).asInstanceOf[FlinkLogicalCorrelate]
- val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
- val rexBuilder = call.builder().getRexBuilder
- val mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(right)
- val mergedCalcProgram = mergedCalc.getProgram
- val input = mergedCalc.getInput
-
- val correlateFilters = RelOptUtil
-
.conjunctions(mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition))
-
- val remainingFilters = correlateFilters.filter(!containsPythonCall(_))
-
- val bottomCalcCondition = RexUtil.composeConjunction(rexBuilder,
remainingFilters)
-
- val newBottomCalc = new FlinkLogicalCalc(
- mergedCalc.getCluster,
- mergedCalc.getTraitSet,
- input,
- RexProgram.create(
- input.getRowType,
- mergedCalcProgram.getProjectList,
- bottomCalcCondition,
- mergedCalc.getRowType,
- rexBuilder))
-
- val newCorrelate = new FlinkLogicalCorrelate(
- correlate.getCluster,
- correlate.getTraitSet,
- correlate.getLeft,
- newBottomCalc,
- correlate.getCorrelationId,
- correlate.getRequiredColumns,
- correlate.getJoinType)
-
- val inputRefRewriter = new InputRefRewriter(
- correlate.getRowType.getFieldCount - mergedCalc.getRowType.getFieldCount)
-
- val pythonFilters = correlateFilters
- .filter(containsPythonCall(_))
- .map(_.accept(inputRefRewriter))
-
- val topCalcCondition = RexUtil.composeConjunction(rexBuilder,
pythonFilters)
-
- val rexProgram = new RexProgramBuilder(newCorrelate.getRowType,
rexBuilder).getProgram
- val newTopCalc = new FlinkLogicalCalc(
- newCorrelate.getCluster,
- newCorrelate.getTraitSet,
- newCorrelate,
- RexProgram.create(
- newCorrelate.getRowType,
- rexProgram.getExprList,
- topCalcCondition,
- newCorrelate.getRowType,
- rexBuilder))
-
- call.transformTo(newTopCalc)
- }
-}
-
-/**
- * Because the inputRef is from the upstream calc node of the correlate node,
so after the inputRef
- * is pushed to the downstream calc node of the correlate node, the inputRef
need to rewrite the
- * index.
- *
- * @param offset
- * the start offset of the inputRef in the downstream calc.
- */
-private class InputRefRewriter(offset: Int) extends RexDefaultVisitor[RexNode]
{
-
- override def visitInputRef(inputRef: RexInputRef): RexNode = {
- new RexInputRef(inputRef.getIndex + offset, inputRef.getType)
- }
-
- override def visitCall(call: RexCall): RexNode = {
- call.clone(call.getType, call.getOperands.asScala.map(_.accept(this)))
- }
-
- override def visitNode(rexNode: RexNode): RexNode = rexNode
-}
-
-object SplitPythonConditionFromCorrelateRule {
- val INSTANCE: SplitPythonConditionFromCorrelateRule = new
SplitPythonConditionFromCorrelateRule
-}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/PythonUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/PythonUtil.scala
deleted file mode 100644
index 8e9133300d7..00000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/PythonUtil.scala
+++ /dev/null
@@ -1,243 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.utils
-
-import org.apache.flink.table.functions.{DeclarativeAggregateFunction,
FunctionDefinition}
-import org.apache.flink.table.functions.python.{PythonFunction,
PythonFunctionKind}
-import
org.apache.flink.table.planner.functions.bridging.{BridgingSqlAggFunction,
BridgingSqlFunction}
-import org.apache.flink.table.planner.functions.utils.{AggSqlFunction,
ScalarSqlFunction, TableSqlFunction}
-import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc
-import
org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction
-
-import org.apache.calcite.plan.hep.HepRelVertex
-import org.apache.calcite.plan.volcano.RelSubset
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.rex.{RexCall, RexFieldAccess, RexInputRef, RexNode}
-import org.apache.calcite.sql.SqlKind
-
-import scala.collection.JavaConversions._
-
object PythonUtil {

  /**
   * Checks whether it contains the specified kind of Python function call in the specified node. If
   * the parameter pythonFunctionKind is null, it will return true for any kind of Python function.
   *
   * @param node
   *   the RexNode to check
   * @param pythonFunctionKind
   *   the kind of the python function, or null to match any kind
   * @return
   *   true if it contains the Python function call in the specified node.
   */
  def containsPythonCall(node: RexNode, pythonFunctionKind: PythonFunctionKind = null): Boolean =
    node.accept(new FunctionFinder(true, Option(pythonFunctionKind), true))

  /**
   * Checks whether it contains non-Python function call in the specified node.
   *
   * @param node
   *   the RexNode to check
   * @return
   *   true if it contains the non-Python function call in the specified node.
   */
  def containsNonPythonCall(node: RexNode): Boolean =
    node.accept(new FunctionFinder(false, None, true))

  /**
   * Checks whether the specified node is the specified kind of Python function call. If the
   * parameter pythonFunctionKind is null, it will return true for any kind of Python function.
   *
   * @param node
   *   the RexNode to check
   * @param pythonFunctionKind
   *   the kind of the python function, or null to match any kind
   * @return
   *   true if the specified node is a Python function call.
   */
  def isPythonCall(node: RexNode, pythonFunctionKind: PythonFunctionKind = null): Boolean =
    node.accept(new FunctionFinder(true, Option(pythonFunctionKind), false))

  /**
   * Checks whether the specified node is a non-Python function call.
   *
   * @param node
   *   the RexNode to check
   * @return
   *   true if the specified node is a non-Python function call.
   */
  def isNonPythonCall(node: RexNode): Boolean =
    node.accept(new FunctionFinder(false, None, false))

  /**
   * Checks whether the specified aggregate is the specified kind of Python function Aggregate.
   *
   * @param call
   *   the AggregateCall to check
   * @param pythonFunctionKind
   *   the kind of the python function, or null to match any kind
   * @return
   *   true if the specified call is a Python function Aggregate; false for any aggregation that is
   *   neither an AggSqlFunction nor a BridgingSqlAggFunction.
   */
  def isPythonAggregate(
      call: AggregateCall,
      pythonFunctionKind: PythonFunctionKind = null): Boolean = {
    call.getAggregation match {
      case function: AggSqlFunction =>
        isPythonFunction(function.aggregateFunction, pythonFunctionKind)
      case function: BridgingSqlAggFunction =>
        isPythonFunction(function.getDefinition, pythonFunctionKind)
      case _ => false
    }
  }

  /**
   * Checks whether the specified aggregate call uses a built-in aggregate function.
   *
   * @param call
   *   the AggregateCall to check
   * @return
   *   for an AggSqlFunction, whether it wraps a BuiltInAggregateFunction; for a
   *   BridgingSqlAggFunction, whether its definition is a DeclarativeAggregateFunction; true for
   *   any other aggregation operator.
   */
  def isBuiltInAggregate(call: AggregateCall): Boolean = {
    call.getAggregation match {
      case function: AggSqlFunction =>
        function.aggregateFunction.isInstanceOf[BuiltInAggregateFunction[_, _]]
      case function: BridgingSqlAggFunction =>
        function.getDefinition.isInstanceOf[DeclarativeAggregateFunction]
      case _ => true
    }
  }

  /**
   * Returns whether the Python function invoked by the given call takes a row as input.
   *
   * @param call
   *   a call whose operator is a ScalarSqlFunction, TableSqlFunction or BridgingSqlFunction
   *   wrapping a PythonFunction
   * @return
   *   the result of PythonFunction#takesRowAsInput for the underlying function definition
   * @throws IllegalArgumentException
   *   if the operator is not one of the supported function wrappers
   * @throws ClassCastException
   *   if the underlying function definition is not a PythonFunction
   */
  def takesRowAsInput(call: RexCall): Boolean = {
    val definition: FunctionDefinition = call.getOperator match {
      case sfc: ScalarSqlFunction => sfc.scalarFunction
      case tfc: TableSqlFunction => tfc.udtf
      case bsf: BridgingSqlFunction => bsf.getDefinition
      case other =>
        // Previously an opaque MatchError; report the offending operator explicitly.
        throw new IllegalArgumentException(
          s"Unsupported operator for a Python function call: $other")
    }
    definition.asInstanceOf[PythonFunction].takesRowAsInput()
  }

  /**
   * Checks whether the given function definition is a PythonFunction of the given kind.
   *
   * @param function
   *   the function definition to check
   * @param pythonFunctionKind
   *   the expected kind, or null to accept any kind of Python function
   */
  private[this] def isPythonFunction(
      function: FunctionDefinition,
      pythonFunctionKind: PythonFunctionKind): Boolean = {
    function match {
      case pythonFunction: PythonFunction =>
        pythonFunctionKind == null || pythonFunction.getPythonFunctionKind == pythonFunctionKind
      case _ => false
    }
  }

  /**
   * Checks whether the given calc only flattens the single struct-typed field produced by its
   * child calc: the child must be a FlinkLogicalCalc, the calc must have no condition, its input
   * must be a single struct field, and the i-th projection must be exactly an access of the i-th
   * nested field of input 0 (optionally wrapped in AS).
   *
   * @param calc
   *   the calc to check
   * @return
   *   true if the calc is a pure flatten of its child's struct output.
   */
  def isFlattenCalc(calc: FlinkLogicalCalc): Boolean = {
    // Inputs may be wrapped by the Volcano (RelSubset) or Hep (HepRelVertex) planner; any other
    // input is already the concrete child node (previously this case threw a MatchError).
    val child = calc.getInput match {
      case relSubset: RelSubset => relSubset.getOriginal
      case hepRelVertex: HepRelVertex => hepRelVertex.getCurrentRel
      case other => other
    }
    if (!child.isInstanceOf[FlinkLogicalCalc]) {
      return false
    }

    val program = calc.getProgram
    if (program.getCondition != null) {
      return false
    }

    val inputFields = program.getInputRowType.getFieldList
    if (inputFields.size != 1 || !inputFields.get(0).getValue.isStruct) {
      return false
    }

    val projects = program.getProjectList.map(program.expandLocalRef)
    if (inputFields.get(0).getValue.getFieldList.size() != projects.size) {
      return false
    }

    projects.zipWithIndex.forall {
      case (project: RexNode, idx: Int) => project.accept(new FieldReferenceDetector(idx))
    }
  }

  /**
   * Checks whether it contains the specified kind of function in a RexNode.
   *
   * @param findPythonFunction
   *   true to find python function, false to find non-python function
   * @param pythonFunctionKind
   *   the kind of the python function; None matches any kind
   * @param recursive
   *   whether check the inputs
   */
  private class FunctionFinder(
      findPythonFunction: Boolean,
      pythonFunctionKind: Option[PythonFunctionKind],
      recursive: Boolean)
    extends RexDefaultVisitor[Boolean] {

    /**
     * Checks whether the specified rexCall is a python function call of the specified kind.
     *
     * @param rexCall
     *   the RexCall to check.
     * @return
     *   true if it is python function call of the specified kind.
     */
    private def isPythonRexCall(rexCall: RexCall): Boolean =
      rexCall.getOperator match {
        case sfc: ScalarSqlFunction => isPythonFunction(sfc.scalarFunction)
        case tfc: TableSqlFunction => isPythonFunction(tfc.udtf)
        case bsf: BridgingSqlFunction => isPythonFunction(bsf.getDefinition)
        case _ => false
      }

    /** True if the definition is a PythonFunction whose kind matches (or no kind was given). */
    private def isPythonFunction(functionDefinition: FunctionDefinition): Boolean =
      functionDefinition match {
        case pythonFunction: PythonFunction =>
          pythonFunctionKind.forall(_ == pythonFunction.getPythonFunctionKind)
        case _ => false
      }

    override def visitCall(call: RexCall): Boolean = {
      // The call itself matches when its "is Python" status equals what we are looking for;
      // otherwise descend into the operands only when a recursive search was requested.
      findPythonFunction == isPythonRexCall(call) ||
      (recursive && call.getOperands.exists(_.accept(this)))
    }

    override def visitFieldAccess(fieldAccess: RexFieldAccess): Boolean = {
      fieldAccess.getReferenceExpr.accept(this)
    }

    override def visitNode(rexNode: RexNode): Boolean = false
  }

  /** Checks whether a rexNode is only a field reference of the given index. */
  private class FieldReferenceDetector(idx: Int) extends RexDefaultVisitor[Boolean] {

    override def visitNode(rexNode: RexNode): Boolean = false

    override def visitFieldAccess(fieldAccess: RexFieldAccess): Boolean = {
      if (fieldAccess.getField.getIndex != idx) {
        return false
      }

      // Only a direct access on input 0 (the single struct field) counts.
      fieldAccess.getReferenceExpr match {
        case ref: RexInputRef => ref.getIndex == 0
        case _ => false
      }
    }

    override def visitCall(call: RexCall): Boolean = {
      // An alias (AS) is transparent: inspect the aliased expression.
      if (call.getKind == SqlKind.AS) {
        call.getOperands.get(0).accept(this)
      } else {
        false
      }
    }
  }
}