This is an automated email from the ASF dual-hosted git repository.

snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new a1b9e42df9e [FLINK-35751][table] Migrate SplitPythonConditionFromCorrelateRule to java
a1b9e42df9e is described below

commit a1b9e42df9e7b6d3a398d1038063e6b7396891fe
Author: Jacky Lau <[email protected]>
AuthorDate: Mon Nov 25 01:56:56 2024 +0800

    [FLINK-35751][table] Migrate SplitPythonConditionFromCorrelateRule to java
---
 .../logical/CalcPythonCorrelateTransposeRule.java  |   4 +-
 .../SplitPythonConditionFromCorrelateRule.java     | 223 +++++++++++++++
 .../flink/table/planner/plan/utils/PythonUtil.java | 316 +++++++++++++++++++++
 .../SplitPythonConditionFromCorrelateRule.scala    | 148 ----------
 .../table/planner/plan/utils/PythonUtil.scala      | 243 ----------------
 5 files changed, 541 insertions(+), 393 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
index 600ccc951f1..0f12cc82328 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/CalcPythonCorrelateTransposeRule.java
@@ -78,8 +78,8 @@ public class CalcPythonCorrelateTransposeRule extends RelOptRule {
                 StreamPhysicalCorrelateRule.getTableScan(mergedCalc);
         RexProgram mergedCalcProgram = mergedCalc.getProgram();
 
-        InputRefRewriter inputRefRewriter =
-                new InputRefRewriter(
+        SplitPythonConditionFromCorrelateRule.InputRefRewriter 
inputRefRewriter =
+                new SplitPythonConditionFromCorrelateRule.InputRefRewriter(
                         correlate.getRowType().getFieldCount()
                                 - mergedCalc.getRowType().getFieldCount());
         List<RexNode> correlateFilters =
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java
new file mode 100644
index 00000000000..bcfe4a43660
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.java
@@ -0,0 +1,223 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel;
+import 
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import 
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
+import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
+
+import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.plan.RelRule;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexProgram;
+import org.apache.calcite.rex.RexProgramBuilder;
+import org.apache.calcite.rex.RexUtil;
+import org.immutables.value.Value;
+
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static 
org.apache.flink.table.planner.plan.utils.PythonUtil.containsPythonCall;
+import static 
org.apache.flink.table.planner.plan.utils.PythonUtil.isNonPythonCall;
+
+/**
+ * Rule will split a {@link FlinkLogicalCalc} which is the upstream of a {@link
+ * FlinkLogicalCorrelate} and contains Python Functions in condition into two {@link
+ * FlinkLogicalCalc}s. One of the {@link FlinkLogicalCalc} without python function condition is the
+ * upstream of the {@link FlinkLogicalCorrelate}, but the other {@link FlinkLogicalCalc} with python
+ * function conditions is the downstream of the {@link FlinkLogicalCorrelate}. Currently, only inner
+ * join is supported.
+ *
+ * <p>After this rule is applied, there will be no Python Functions in the condition of the upstream
+ * {@link FlinkLogicalCalc}.
+ */
+@Value.Enclosing
+public class SplitPythonConditionFromCorrelateRule
+        extends RelRule<
+                
SplitPythonConditionFromCorrelateRule.SplitPythonConditionFromCorrelateRuleConfig>
 {
+
+    public static final SplitPythonConditionFromCorrelateRule INSTANCE =
+            
SplitPythonConditionFromCorrelateRule.SplitPythonConditionFromCorrelateRuleConfig
+                    .DEFAULT
+                    .toRule();
+
+    private SplitPythonConditionFromCorrelateRule(
+            SplitPythonConditionFromCorrelateRuleConfig config) {
+        super(config);
+    }
+
+    @Override
+    public boolean matches(RelOptRuleCall call) {
+        FlinkLogicalCorrelate correlate = call.rel(0);
+        FlinkLogicalCalc right = call.rel(2);
+        JoinRelType joinType = correlate.getJoinType();
+        FlinkLogicalCalc mergedCalc = 
StreamPhysicalCorrelateRule.getMergedCalc(right);
+        FlinkLogicalTableFunctionScan tableScan =
+                StreamPhysicalCorrelateRule.getTableScan(mergedCalc);
+
+        return joinType == JoinRelType.INNER
+                && isNonPythonCall(tableScan.getCall())
+                && mergedCalc.getProgram() != null
+                && mergedCalc.getProgram().getCondition() != null
+                && containsPythonCall(
+                        mergedCalc
+                                .getProgram()
+                                
.expandLocalRef(mergedCalc.getProgram().getCondition()),
+                        null);
+    }
+
+    @Override
+    public void onMatch(RelOptRuleCall call) {
+        FlinkLogicalCorrelate correlate = call.rel(0);
+        FlinkLogicalCalc right = call.rel(2);
+        RexBuilder rexBuilder = call.builder().getRexBuilder();
+        FlinkLogicalCalc mergedCalc = 
StreamPhysicalCorrelateRule.getMergedCalc(right);
+        RexProgram mergedCalcProgram = mergedCalc.getProgram();
+        RelNode input = mergedCalc.getInput();
+
+        List<RexNode> correlateFilters =
+                RelOptUtil.conjunctions(
+                        
mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition()));
+
+        List<RexNode> remainingFilters =
+                correlateFilters.stream()
+                        .filter(filter -> !containsPythonCall(filter))
+                        .collect(Collectors.toList());
+
+        RexNode bottomCalcCondition = RexUtil.composeConjunction(rexBuilder, 
remainingFilters);
+
+        FlinkLogicalCalc newBottomCalc =
+                new FlinkLogicalCalc(
+                        mergedCalc.getCluster(),
+                        mergedCalc.getTraitSet(),
+                        input,
+                        RexProgram.create(
+                                input.getRowType(),
+                                mergedCalcProgram.getProjectList(),
+                                bottomCalcCondition,
+                                mergedCalc.getRowType(),
+                                rexBuilder));
+
+        FlinkLogicalCorrelate newCorrelate =
+                new FlinkLogicalCorrelate(
+                        correlate.getCluster(),
+                        correlate.getTraitSet(),
+                        correlate.getLeft(),
+                        newBottomCalc,
+                        correlate.getCorrelationId(),
+                        correlate.getRequiredColumns(),
+                        correlate.getJoinType());
+
+        InputRefRewriter inputRefRewriter =
+                new InputRefRewriter(
+                        correlate.getRowType().getFieldCount()
+                                - mergedCalc.getRowType().getFieldCount());
+
+        List<RexNode> pythonFilters =
+                correlateFilters.stream()
+                        .filter(filter -> containsPythonCall(filter))
+                        .map(filter -> filter.accept(inputRefRewriter))
+                        .collect(Collectors.toList());
+
+        RexNode topCalcCondition = RexUtil.composeConjunction(rexBuilder, 
pythonFilters);
+
+        RexProgram rexProgram =
+                new RexProgramBuilder(newCorrelate.getRowType(), 
rexBuilder).getProgram();
+        FlinkLogicalCalc newTopCalc =
+                new FlinkLogicalCalc(
+                        newCorrelate.getCluster(),
+                        newCorrelate.getTraitSet(),
+                        newCorrelate,
+                        RexProgram.create(
+                                newCorrelate.getRowType(),
+                                rexProgram.getExprList(),
+                                topCalcCondition,
+                                newCorrelate.getRowType(),
+                                rexBuilder));
+
+        call.transformTo(newTopCalc);
+    }
+
+    /** Rule configuration. */
+    @Value.Immutable(singleton = false)
+    public interface SplitPythonConditionFromCorrelateRuleConfig extends 
RelRule.Config {
+        
SplitPythonConditionFromCorrelateRule.SplitPythonConditionFromCorrelateRuleConfig
 DEFAULT =
+                ImmutableSplitPythonConditionFromCorrelateRule
+                        .SplitPythonConditionFromCorrelateRuleConfig.builder()
+                        .build()
+                        .withOperandSupplier(
+                                b0 ->
+                                        b0.operand(FlinkLogicalCorrelate.class)
+                                                .inputs(
+                                                        b1 ->
+                                                                
b1.operand(FlinkLogicalRel.class)
+                                                                        
.anyInputs(),
+                                                        b2 ->
+                                                                
b2.operand(FlinkLogicalCalc.class)
+                                                                        
.anyInputs()))
+                        
.withDescription("SplitPythonConditionFromCorrelateRule");
+
+        @Override
+        default SplitPythonConditionFromCorrelateRule toRule() {
+            return new SplitPythonConditionFromCorrelateRule(this);
+        }
+    }
+
+    /**
+     * Because the inputRef is from the upstream calc node of the correlate node, so after the
+     * inputRef is pushed to the downstream calc node of the correlate node, the inputRef need to
+     * rewrite the index.
+     */
+    static class InputRefRewriter extends RexDefaultVisitor<RexNode> {
+        private final int offset;
+
+        /** @param offset the start offset of the inputRef in the downstream calc. */
+        public InputRefRewriter(int offset) {
+            this.offset = offset;
+        }
+
+        @Override
+        public RexNode visitInputRef(RexInputRef inputRef) {
+            return new RexInputRef(inputRef.getIndex() + offset, 
inputRef.getType());
+        }
+
+        @Override
+        public RexNode visitCall(RexCall call) {
+            return call.clone(
+                    call.getType(),
+                    call.getOperands().stream()
+                            .map(o -> o.accept(this))
+                            .collect(Collectors.toList()));
+        }
+
+        @Override
+        public RexNode visitNode(RexNode rexNode) {
+            return rexNode;
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java
new file mode 100644
index 00000000000..65efb05f430
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/utils/PythonUtil.java
@@ -0,0 +1,316 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.utils;
+
+import org.apache.flink.table.functions.DeclarativeAggregateFunction;
+import org.apache.flink.table.functions.FunctionDefinition;
+import org.apache.flink.table.functions.python.PythonFunction;
+import org.apache.flink.table.functions.python.PythonFunctionKind;
+import 
org.apache.flink.table.planner.functions.bridging.BridgingSqlAggFunction;
+import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction;
+import org.apache.flink.table.planner.functions.utils.AggSqlFunction;
+import org.apache.flink.table.planner.functions.utils.ScalarSqlFunction;
+import org.apache.flink.table.planner.functions.utils.TableSqlFunction;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import 
org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction;
+
+import org.apache.calcite.plan.hep.HepRelVertex;
+import org.apache.calcite.plan.volcano.RelSubset;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.type.RelDataTypeField;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.SqlKind;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/** Utility for Python. */
+public class PythonUtil {
+
+    /**
+     * Checks whether it contains the specified kind of Python function call in the specified node.
+     * If the parameter pythonFunctionKind is null, it will return true for any kind of Python
+     * function.
+     *
+     * @param node the RexNode to check
+     * @param pythonFunctionKind the kind of the python function
+     * @return true if it contains the Python function call in the specified node.
+     */
+    public static boolean containsPythonCall(RexNode node, PythonFunctionKind 
pythonFunctionKind) {
+        FunctionFinder functionFinder =
+                new FunctionFinder(true, 
Optional.ofNullable(pythonFunctionKind), true);
+        return node.accept(functionFinder);
+    }
+
+    public static boolean containsPythonCall(RexNode node) {
+        return containsPythonCall(node, null);
+    }
+
+    /**
+     * Checks whether it contains non-Python function call in the specified node.
+     *
+     * @param node the RexNode to check
+     * @return true if it contains the non-Python function call in the specified node.
+     */
+    public static boolean containsNonPythonCall(RexNode node) {
+        FunctionFinder functionFinder = new FunctionFinder(false, 
Optional.empty(), true);
+        return node.accept(functionFinder);
+    }
+
+    /**
+     * Checks whether the specified node is the specified kind of Python function call. If the
+     * parameter pythonFunctionKind is null, it will return true for any kind of Python function.
+     *
+     * @param node the RexNode to check
+     * @param pythonFunctionKind the kind of the python function
+     * @return true if the specified node is a Python function call.
+     */
+    public static boolean isPythonCall(RexNode node, PythonFunctionKind 
pythonFunctionKind) {
+        FunctionFinder functionFinder =
+                new FunctionFinder(true, 
Optional.ofNullable(pythonFunctionKind), false);
+        return node.accept(functionFinder);
+    }
+
+    public static boolean isPythonCall(RexNode node) {
+        return isPythonCall(node, null);
+    }
+
+    /**
+     * Checks whether the specified node is a non-Python function call.
+     *
+     * @param node the RexNode to check
+     * @return true if the specified node is a non-Python function call.
+     */
+    public static boolean isNonPythonCall(RexNode node) {
+        FunctionFinder functionFinder = new FunctionFinder(false, 
Optional.empty(), false);
+        return node.accept(functionFinder);
+    }
+
+    public static boolean isPythonAggregate(AggregateCall call) {
+        return isPythonAggregate(call, null);
+    }
+
+    /**
+     * Checks whether the specified aggregate is the specified kind of Python function Aggregate.
+     *
+     * @param call the AggregateCall to check
+     * @param pythonFunctionKind the kind of the python function
+     * @return true if the specified call is a Python function Aggregate.
+     */
+    public static boolean isPythonAggregate(
+            AggregateCall call, PythonFunctionKind pythonFunctionKind) {
+        SqlAggFunction aggregation = call.getAggregation();
+        if (aggregation instanceof AggSqlFunction) {
+            return isPythonFunction(
+                    ((AggSqlFunction) aggregation).aggregateFunction(), 
pythonFunctionKind);
+        } else if (aggregation instanceof BridgingSqlAggFunction) {
+            return isPythonFunction(
+                    ((BridgingSqlAggFunction) aggregation).getDefinition(), 
pythonFunctionKind);
+        } else {
+            return false;
+        }
+    }
+
+    public static boolean isBuiltInAggregate(AggregateCall call) {
+        SqlAggFunction aggregation = call.getAggregation();
+        if (aggregation instanceof AggSqlFunction) {
+            AggSqlFunction aggSqlFunction = (AggSqlFunction) aggregation;
+            return aggSqlFunction.aggregateFunction() instanceof 
BuiltInAggregateFunction;
+        } else if (aggregation instanceof BridgingSqlAggFunction) {
+            BridgingSqlAggFunction bridgingSqlAggFunction = 
(BridgingSqlAggFunction) aggregation;
+            return bridgingSqlAggFunction.getDefinition() instanceof 
DeclarativeAggregateFunction;
+        } else {
+            return true;
+        }
+    }
+
+    public static boolean takesRowAsInput(RexCall call) {
+        if (call.getOperator() instanceof ScalarSqlFunction) {
+            ScalarSqlFunction sfc = (ScalarSqlFunction) call.getOperator();
+            return ((PythonFunction) sfc.scalarFunction()).takesRowAsInput();
+        } else if (call.getOperator() instanceof TableSqlFunction) {
+            TableSqlFunction tfc = (TableSqlFunction) call.getOperator();
+            return ((PythonFunction) tfc.udtf()).takesRowAsInput();
+        } else if (call.getOperator() instanceof BridgingSqlFunction) {
+            BridgingSqlFunction bsf = (BridgingSqlFunction) call.getOperator();
+            return ((PythonFunction) bsf.getDefinition()).takesRowAsInput();
+        }
+        return false;
+    }
+
+    private static boolean isPythonFunction(
+            FunctionDefinition function, PythonFunctionKind 
pythonFunctionKind) {
+        if (function instanceof PythonFunction) {
+            PythonFunction pythonFunction = (PythonFunction) function;
+            return pythonFunctionKind == null
+                    || pythonFunction.getPythonFunctionKind() == 
pythonFunctionKind;
+        } else {
+            return false;
+        }
+    }
+
+    public static boolean isFlattenCalc(FlinkLogicalCalc calc) {
+        RelNode child = calc.getInput();
+        if (child instanceof RelSubset) {
+            child = ((RelSubset) child).getOriginal();
+        } else if (child instanceof HepRelVertex) {
+            child = ((HepRelVertex) child).getCurrentRel();
+        } else {
+            return false;
+        }
+        if (!(child instanceof FlinkLogicalCalc)) {
+            return false;
+        }
+
+        if (calc.getProgram().getCondition() != null) {
+            return false;
+        }
+
+        List<RelDataTypeField> inputFields = 
calc.getProgram().getInputRowType().getFieldList();
+        if (inputFields.size() != 1 || 
!inputFields.get(0).getType().isStruct()) {
+            return false;
+        }
+
+        List<RexNode> projects =
+                calc.getProgram().getProjectList().stream()
+                        .map(calc.getProgram()::expandLocalRef)
+                        .collect(Collectors.toList());
+
+        if (inputFields.get(0).getType().getFieldCount() != projects.size()) {
+            return false;
+        }
+
+        return IntStream.range(0, projects.size())
+                .allMatch(idx -> projects.get(idx).accept(new 
FieldReferenceDetector(idx)));
+    }
+
+    private static class FunctionFinder extends RexDefaultVisitor<Boolean> {
+        private final boolean findPythonFunction;
+        private final Optional<PythonFunctionKind> pythonFunctionKind;
+        private final boolean recursive;
+
+        /**
+         * Checks whether it contains the specified kind of function in a RexNode.
+         *
+         * @param findPythonFunction true to find python function, false to find non-python function
+         * @param pythonFunctionKind the kind of the python function
+         * @param recursive whether check the inputs
+         */
+        public FunctionFinder(
+                boolean findPythonFunction,
+                Optional<PythonFunctionKind> pythonFunctionKind,
+                boolean recursive) {
+            this.findPythonFunction = findPythonFunction;
+            this.pythonFunctionKind = pythonFunctionKind;
+            this.recursive = recursive;
+        }
+
+        /**
+         * Checks whether the specified rexCall is a python function call of the specified kind.
+         *
+         * @param rexCall the RexCall to check.
+         * @return true if it is python function call of the specified kind.
+         */
+        private boolean isPythonRexCall(RexCall rexCall) {
+            if (rexCall.getOperator() instanceof ScalarSqlFunction) {
+                ScalarSqlFunction sfc = (ScalarSqlFunction) 
rexCall.getOperator();
+                return isPythonFunction(sfc.scalarFunction());
+            } else if (rexCall.getOperator() instanceof TableSqlFunction) {
+                TableSqlFunction tfc = (TableSqlFunction) 
rexCall.getOperator();
+                return isPythonFunction(tfc.udtf());
+            } else if (rexCall.getOperator() instanceof BridgingSqlFunction) {
+                BridgingSqlFunction bsf = (BridgingSqlFunction) 
rexCall.getOperator();
+                return isPythonFunction(bsf.getDefinition());
+            } else {
+                return false;
+            }
+        }
+
+        private boolean isPythonFunction(FunctionDefinition 
functionDefinition) {
+            if (functionDefinition instanceof PythonFunction) {
+                PythonFunction pythonFunction = (PythonFunction) 
functionDefinition;
+                return !pythonFunctionKind.isPresent()
+                        || pythonFunction.getPythonFunctionKind() == 
pythonFunctionKind.get();
+            } else {
+                return false;
+            }
+        }
+
+        @Override
+        public Boolean visitCall(RexCall call) {
+            return findPythonFunction == isPythonRexCall(call)
+                    || (recursive
+                            && call.getOperands().stream()
+                                    .anyMatch(operand -> 
operand.accept(this)));
+        }
+
+        @Override
+        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
+            return fieldAccess.getReferenceExpr().accept(this);
+        }
+
+        @Override
+        public Boolean visitNode(RexNode rexNode) {
+            return false;
+        }
+    }
+
+    /** Checks whether a rexNode is only a field reference of the given index. */
+    private static class FieldReferenceDetector extends 
RexDefaultVisitor<Boolean> {
+        private final int idx;
+
+        public FieldReferenceDetector(int idx) {
+            this.idx = idx;
+        }
+
+        @Override
+        public Boolean visitNode(RexNode rexNode) {
+            return false;
+        }
+
+        @Override
+        public Boolean visitFieldAccess(RexFieldAccess fieldAccess) {
+            if (fieldAccess.getField().getIndex() != idx) {
+                return false;
+            }
+            RexNode expr = fieldAccess.getReferenceExpr();
+            if (expr instanceof RexInputRef) {
+                return ((RexInputRef) expr).getIndex() == 0;
+            } else {
+                return false;
+            }
+        }
+
+        @Override
+        public Boolean visitCall(RexCall call) {
+            if (call.getKind() == SqlKind.AS) {
+                return call.getOperands().get(0).accept(this);
+            } else {
+                return false;
+            }
+        }
+    }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
deleted file mode 100644
index e6ccaa13baa..00000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/SplitPythonConditionFromCorrelateRule.scala
+++ /dev/null
@@ -1,148 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.rules.logical
-
-import org.apache.flink.table.planner.plan.nodes.logical.{FlinkLogicalCalc, 
FlinkLogicalCorrelate, FlinkLogicalRel, FlinkLogicalTableFunctionScan}
-import 
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule
-import 
org.apache.flink.table.planner.plan.utils.PythonUtil.{containsPythonCall, 
isNonPythonCall}
-import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor
-
-import org.apache.calcite.plan.{RelOptRule, RelOptRuleCall, RelOptUtil}
-import org.apache.calcite.plan.RelOptRule.{any, operand}
-import org.apache.calcite.rel.core.JoinRelType
-import org.apache.calcite.rex._
-
-import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
-
-/**
- * Rule will split a [[FlinkLogicalCalc]] which is the upstream of a [[FlinkLogicalCorrelate]] and
- * contains Python Functions in condition into two [[FlinkLogicalCalc]]s. One of the
- * [[FlinkLogicalCalc]] without python function condition is the upstream of the
- * [[FlinkLogicalCorrelate]], but the other [[[FlinkLogicalCalc]] with python function conditions is
- * the downstream of the [[FlinkLogicalCorrelate]]. Currently, only inner join is supported.
- *
- * After this rule is applied, there will be no Python Functions in the condition of the upstream
- * [[FlinkLogicalCalc]].
- */
-class SplitPythonConditionFromCorrelateRule
-  extends RelOptRule(
-    operand(
-      classOf[FlinkLogicalCorrelate],
-      operand(classOf[FlinkLogicalRel], any),
-      operand(classOf[FlinkLogicalCalc], any)),
-    "SplitPythonConditionFromCorrelateRule") {
-  override def matches(call: RelOptRuleCall): Boolean = {
-    val correlate: FlinkLogicalCorrelate = 
call.rel(0).asInstanceOf[FlinkLogicalCorrelate]
-    val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
-    val joinType: JoinRelType = correlate.getJoinType
-    val mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(right)
-    val tableScan = StreamPhysicalCorrelateRule
-      .getTableScan(mergedCalc)
-      .asInstanceOf[FlinkLogicalTableFunctionScan]
-    joinType == JoinRelType.INNER &&
-    isNonPythonCall(tableScan.getCall) &&
-    Option(mergedCalc.getProgram.getCondition)
-      .map(mergedCalc.getProgram.expandLocalRef)
-      .exists(containsPythonCall(_))
-  }
-
-  override def onMatch(call: RelOptRuleCall): Unit = {
-    val correlate: FlinkLogicalCorrelate = 
call.rel(0).asInstanceOf[FlinkLogicalCorrelate]
-    val right: FlinkLogicalCalc = call.rel(2).asInstanceOf[FlinkLogicalCalc]
-    val rexBuilder = call.builder().getRexBuilder
-    val mergedCalc = StreamPhysicalCorrelateRule.getMergedCalc(right)
-    val mergedCalcProgram = mergedCalc.getProgram
-    val input = mergedCalc.getInput
-
-    val correlateFilters = RelOptUtil
-      
.conjunctions(mergedCalcProgram.expandLocalRef(mergedCalcProgram.getCondition))
-
-    val remainingFilters = correlateFilters.filter(!containsPythonCall(_))
-
-    val bottomCalcCondition = RexUtil.composeConjunction(rexBuilder, 
remainingFilters)
-
-    val newBottomCalc = new FlinkLogicalCalc(
-      mergedCalc.getCluster,
-      mergedCalc.getTraitSet,
-      input,
-      RexProgram.create(
-        input.getRowType,
-        mergedCalcProgram.getProjectList,
-        bottomCalcCondition,
-        mergedCalc.getRowType,
-        rexBuilder))
-
-    val newCorrelate = new FlinkLogicalCorrelate(
-      correlate.getCluster,
-      correlate.getTraitSet,
-      correlate.getLeft,
-      newBottomCalc,
-      correlate.getCorrelationId,
-      correlate.getRequiredColumns,
-      correlate.getJoinType)
-
-    val inputRefRewriter = new InputRefRewriter(
-      correlate.getRowType.getFieldCount - mergedCalc.getRowType.getFieldCount)
-
-    val pythonFilters = correlateFilters
-      .filter(containsPythonCall(_))
-      .map(_.accept(inputRefRewriter))
-
-    val topCalcCondition = RexUtil.composeConjunction(rexBuilder, 
pythonFilters)
-
-    val rexProgram = new RexProgramBuilder(newCorrelate.getRowType, 
rexBuilder).getProgram
-    val newTopCalc = new FlinkLogicalCalc(
-      newCorrelate.getCluster,
-      newCorrelate.getTraitSet,
-      newCorrelate,
-      RexProgram.create(
-        newCorrelate.getRowType,
-        rexProgram.getExprList,
-        topCalcCondition,
-        newCorrelate.getRowType,
-        rexBuilder))
-
-    call.transformTo(newTopCalc)
-  }
-}
-
-/**
- * Because the inputRef is from the upstream calc node of the correlate node, so after the inputRef
- * is pushed to the downstream calc node of the correlate node, the inputRef need to rewrite the
- * index.
- *
- * @param offset
- *   the start offset of the inputRef in the downstream calc.
- */
-private class InputRefRewriter(offset: Int) extends RexDefaultVisitor[RexNode] 
{
-
-  override def visitInputRef(inputRef: RexInputRef): RexNode = {
-    new RexInputRef(inputRef.getIndex + offset, inputRef.getType)
-  }
-
-  override def visitCall(call: RexCall): RexNode = {
-    call.clone(call.getType, call.getOperands.asScala.map(_.accept(this)))
-  }
-
-  override def visitNode(rexNode: RexNode): RexNode = rexNode
-}
-
-object SplitPythonConditionFromCorrelateRule {
-  val INSTANCE: SplitPythonConditionFromCorrelateRule = new 
SplitPythonConditionFromCorrelateRule
-}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/PythonUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/PythonUtil.scala
deleted file mode 100644
index 8e9133300d7..00000000000
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/PythonUtil.scala
+++ /dev/null
@@ -1,243 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.flink.table.planner.plan.utils
-
-import org.apache.flink.table.functions.{DeclarativeAggregateFunction, FunctionDefinition}
-import org.apache.flink.table.functions.python.{PythonFunction, PythonFunctionKind}
-import org.apache.flink.table.planner.functions.bridging.{BridgingSqlAggFunction, BridgingSqlFunction}
-import org.apache.flink.table.planner.functions.utils.{AggSqlFunction, ScalarSqlFunction, TableSqlFunction}
-import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc
-import org.apache.flink.table.runtime.functions.aggregate.BuiltInAggregateFunction
-
-import org.apache.calcite.plan.hep.HepRelVertex
-import org.apache.calcite.plan.volcano.RelSubset
-import org.apache.calcite.rel.core.AggregateCall
-import org.apache.calcite.rex.{RexCall, RexFieldAccess, RexInputRef, RexNode}
-import org.apache.calcite.sql.SqlKind
-
-import scala.collection.JavaConversions._
-
-object PythonUtil {
-
-  /**
-   * Checks whether it contains the specified kind of Python function call in the specified node. If
-   * the parameter pythonFunctionKind is null, it will return true for any kind of Python function.
-   *
-   * @param node
-   *   the RexNode to check
-   * @param pythonFunctionKind
-   *   the kind of the python function
-   * @return
-   *   true if it contains the Python function call in the specified node.
-   */
-  def containsPythonCall(node: RexNode, pythonFunctionKind: PythonFunctionKind = null): Boolean =
-    node.accept(new FunctionFinder(true, Option(pythonFunctionKind), true))
-
-  /**
-   * Checks whether it contains non-Python function call in the specified node.
-   *
-   * @param node
-   *   the RexNode to check
-   * @return
-   *   true if it contains the non-Python function call in the specified node.
-   */
-  def containsNonPythonCall(node: RexNode): Boolean =
-    node.accept(new FunctionFinder(false, None, true))
-
-  /**
-   * Checks whether the specified node is the specified kind of Python function call. If the
-   * parameter pythonFunctionKind is null, it will return true for any kind of Python function.
-   *
-   * @param node
-   *   the RexNode to check
-   * @param pythonFunctionKind
-   *   the kind of the python function
-   * @return
-   *   true if the specified node is a Python function call.
-   */
-  def isPythonCall(node: RexNode, pythonFunctionKind: PythonFunctionKind = null): Boolean =
-    node.accept(new FunctionFinder(true, Option(pythonFunctionKind), false))
-
-  /**
-   * Checks whether the specified node is a non-Python function call.
-   *
-   * @param node
-   *   the RexNode to check
-   * @return
-   *   true if the specified node is a non-Python function call.
-   */
-  def isNonPythonCall(node: RexNode): Boolean = node.accept(new FunctionFinder(false, None, false))
-
-  /**
-   * Checks whether the specified aggregate is the specified kind of Python function Aggregate.
-   *
-   * @param call
-   *   the AggregateCall to check
-   * @param pythonFunctionKind
-   *   the kind of the python function
-   * @return
-   *   true if the specified call is a Python function Aggregate.
-   */
-  def isPythonAggregate(
-      call: AggregateCall,
-      pythonFunctionKind: PythonFunctionKind = null): Boolean = {
-    val aggregation = call.getAggregation
-    aggregation match {
-      case function: AggSqlFunction =>
-        isPythonFunction(function.aggregateFunction, pythonFunctionKind)
-      case function: BridgingSqlAggFunction =>
-        isPythonFunction(function.getDefinition, pythonFunctionKind)
-      case _ => false
-    }
-  }
-
-  def isBuiltInAggregate(call: AggregateCall): Boolean = {
-    val aggregation = call.getAggregation
-    aggregation match {
-      case function: AggSqlFunction =>
-        function.aggregateFunction.isInstanceOf[BuiltInAggregateFunction[_, _]]
-      case function: BridgingSqlAggFunction =>
-        function.getDefinition.isInstanceOf[DeclarativeAggregateFunction]
-      case _ => true
-    }
-  }
-
-  def takesRowAsInput(call: RexCall): Boolean = {
-    (call.getOperator match {
-      case sfc: ScalarSqlFunction => sfc.scalarFunction
-      case tfc: TableSqlFunction => tfc.udtf
-      case bsf: BridgingSqlFunction => bsf.getDefinition
-    }).asInstanceOf[PythonFunction].takesRowAsInput()
-  }
-
-  private[this] def isPythonFunction(
-      function: FunctionDefinition,
-      pythonFunctionKind: PythonFunctionKind): Boolean = {
-    function match {
-      case pythonFunction: PythonFunction =>
-        pythonFunctionKind == null || pythonFunction.getPythonFunctionKind == pythonFunctionKind
-      case _ => false
-    }
-  }
-
-  def isFlattenCalc(calc: FlinkLogicalCalc): Boolean = {
-    val child = calc.getInput match {
-      case relSubset: RelSubset => relSubset.getOriginal
-      case hepRelVertex: HepRelVertex => hepRelVertex.getCurrentRel
-    }
-    if (!child.isInstanceOf[FlinkLogicalCalc]) {
-      return false
-    }
-
-    if (calc.getProgram.getCondition != null) {
-      return false
-    }
-
-    val inputFields = calc.getProgram.getInputRowType.getFieldList
-    if (inputFields.size != 1 || !inputFields.get(0).getValue.isStruct) {
-      return false
-    }
-
-    val projects = calc.getProgram.getProjectList.map(calc.getProgram.expandLocalRef)
-
-    if (inputFields.get(0).getValue.getFieldList.size() != projects.size) {
-      return false
-    }
-
-    projects.zipWithIndex.forall {
-      case (project: RexNode, idx: Int) => project.accept(new FieldReferenceDetector(idx))
-    }
-  }
-
-  /**
-   * Checks whether it contains the specified kind of function in a RexNode.
-   *
-   * @param findPythonFunction
-   *   true to find python function, false to find non-python function
-   * @param pythonFunctionKind
-   *   the kind of the python function
-   * @param recursive
-   *   whether check the inputs
-   */
-  private class FunctionFinder(
-      findPythonFunction: Boolean,
-      pythonFunctionKind: Option[PythonFunctionKind],
-      recursive: Boolean)
-    extends RexDefaultVisitor[Boolean] {
-
-    /**
-     * Checks whether the specified rexCall is a python function call of the specified kind.
-     *
-     * @param rexCall
-     *   the RexCall to check.
-     * @return
-     *   true if it is python function call of the specified kind.
-     */
-    private def isPythonRexCall(rexCall: RexCall): Boolean =
-      rexCall.getOperator match {
-        case sfc: ScalarSqlFunction => isPythonFunction(sfc.scalarFunction)
-        case tfc: TableSqlFunction => isPythonFunction(tfc.udtf)
-        case bsf: BridgingSqlFunction => isPythonFunction(bsf.getDefinition)
-        case _ => false
-      }
-
-    private def isPythonFunction(functionDefinition: FunctionDefinition): Boolean = {
-      functionDefinition.isInstanceOf[PythonFunction] &&
-      (pythonFunctionKind.isEmpty ||
-        functionDefinition.asInstanceOf[PythonFunction].getPythonFunctionKind ==
-        pythonFunctionKind.get)
-    }
-
-    override def visitCall(call: RexCall): Boolean = {
-      findPythonFunction == isPythonRexCall(call) ||
-      (recursive && call.getOperands.exists(_.accept(this)))
-    }
-
-    override def visitFieldAccess(fieldAccess: RexFieldAccess): Boolean = {
-      fieldAccess.getReferenceExpr.accept(this)
-    }
-
-    override def visitNode(rexNode: RexNode): Boolean = false
-  }
-
-  /** Checks whether a rexNode is only a field reference of the given index. */
-  private class FieldReferenceDetector(idx: Int) extends RexDefaultVisitor[Boolean] {
-
-    override def visitNode(rexNode: RexNode): Boolean = false
-
-    override def visitFieldAccess(fieldAccess: RexFieldAccess): Boolean = {
-      if (fieldAccess.getField.getIndex != idx) {
-        return false
-      }
-
-      val expr: RexNode = fieldAccess.getReferenceExpr
-      expr match {
-        case ref: RexInputRef => ref.getIndex == 0
-        case _ => false
-      }
-    }
-
-    override def visitCall(call: RexCall): Boolean = {
-      if (call.getKind == SqlKind.AS) {
-        call.getOperands.get(0).accept(this)
-      } else {
-        false
-      }
-    }
-  }
-}


Reply via email to