This is an automated email from the ASF dual-hosted git repository.

dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 53616d61e22 [FLINK-37725] Makes Async calcs share correlate split rule 
with Python (#26505)
53616d61e22 is described below

commit 53616d61e22491931bec10573eeee316ca762732
Author: Alan Sheinberg <[email protected]>
AuthorDate: Thu May 15 05:03:48 2025 -0700

    [FLINK-37725] Makes Async calcs share correlate split rule with Python 
(#26505)
---
 .../plan/rules/logical/AsyncCalcSplitRule.java     |  15 ++
 .../rules/logical/AsyncCorrelateSplitRule.java     |  40 +++
 .../rules/logical/PythonCorrelateSplitRule.java    | 296 +--------------------
 ...plitRule.java => RemoteCorrelateSplitRule.java} |  99 +++++--
 .../planner/plan/rules/FlinkStreamRuleSets.scala   |   4 +-
 .../plan/rules/logical/PythonCalcSplitRule.scala   |  10 +
 .../plan/rules/logical/RemoteCalcCallFinder.java   |   3 +
 .../plan/rules/logical/AsyncCalcSplitRuleTest.java |  11 +
 .../rules/logical/AsyncCorrelateSplitRuleTest.java |  85 ++++++
 .../runtime/stream/table/AsyncCalcITCase.java      |  26 ++
 .../rules/logical/AsyncCorrelateSplitRuleTest.xml  |  84 ++++++
 11 files changed, 353 insertions(+), 320 deletions(-)

diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
index aa0b8b85528..cc7c72a909e 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
@@ -90,6 +90,21 @@ public class AsyncCalcSplitRule {
         public boolean isNonRemoteCall(RexNode node) {
             return AsyncUtil.isNonAsyncCall(node);
         }
+
+        @Override
+        public String getName() {
+            return "Async";
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            return obj != null && this.getClass() == obj.getClass();
+        }
+
+        @Override
+        public int hashCode() {
+            return this.getClass().hashCode();
+        }
     }
 
     private static boolean hasNestedCalls(List<RexNode> projects) {
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java
new file mode 100644
index 00000000000..0094094c0f1
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
+import 
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import 
org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRule.AsyncRemoteCalcCallFinder;
+
+import org.apache.calcite.plan.RelOptRule;
+
+/**
+ * Rule will split the Async {@link FlinkLogicalTableFunctionScan} with Java 
calls or the Java
+ * {@link FlinkLogicalTableFunctionScan} with Async calls into a {@link 
FlinkLogicalCalc} which will
+ * be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link
+ * FlinkLogicalTableFunctionScan}.
+ */
+public class AsyncCorrelateSplitRule {
+
+    private static final RemoteCalcCallFinder ASYNC_CALL_FINDER = new 
AsyncRemoteCalcCallFinder();
+
+    public static final RelOptRule INSTANCE =
+            
RemoteCorrelateSplitRule.Config.createDefault(ASYNC_CALL_FINDER).toRule();
+}
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
index 986f0fc538c..5a60649e8ae 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
@@ -21,33 +21,8 @@ package org.apache.flink.table.planner.plan.rules.logical;
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
 import 
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
-import 
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
-import org.apache.flink.table.planner.plan.utils.PythonUtil;
-import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
 
 import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.plan.hep.HepRelVertex;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexCorrelVariable;
-import org.apache.calcite.rex.RexFieldAccess;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexProgram;
-import org.apache.calcite.rex.RexProgramBuilder;
-import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.sql.validate.SqlValidatorUtil;
-
-import java.util.LinkedList;
-import java.util.List;
-import java.util.stream.Collectors;
-
-import scala.collection.Iterator;
-import scala.collection.mutable.ArrayBuffer;
 
 /**
  * Rule will split the Python {@link FlinkLogicalTableFunctionScan} with Java 
calls or the Java
@@ -55,272 +30,9 @@ import scala.collection.mutable.ArrayBuffer;
  * will be the left input of the new {@link FlinkLogicalCorrelate} and a new 
{@link
  * FlinkLogicalTableFunctionScan}.
  */
-public class PythonCorrelateSplitRule extends RelOptRule {
-    public static final PythonCorrelateSplitRule INSTANCE = new 
PythonCorrelateSplitRule();
-
-    private PythonCorrelateSplitRule() {
-        super(operand(FlinkLogicalCorrelate.class, any()), 
"PythonCorrelateSplitRule");
-    }
-
-    private FlinkLogicalTableFunctionScan createNewScan(
-            FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter 
splitter) {
-        RexCall rightRexCall = (RexCall) scan.getCall();
-        // extract Java funcs from Python TableFunction or Python funcs from 
Java TableFunction.
-        List<RexNode> rightCalcProjects =
-                rightRexCall.getOperands().stream()
-                        .map(x -> x.accept(splitter))
-                        .collect(Collectors.toList());
-
-        RexCall newRightRexCall = rightRexCall.clone(rightRexCall.getType(), 
rightCalcProjects);
-        return new FlinkLogicalTableFunctionScan(
-                scan.getCluster(),
-                scan.getTraitSet(),
-                scan.getInputs(),
-                newRightRexCall,
-                scan.getElementType(),
-                scan.getRowType(),
-                scan.getColumnMappings());
-    }
-
-    @Override
-    public boolean matches(RelOptRuleCall call) {
-        FlinkLogicalCorrelate correlate = call.rel(0);
-        RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel();
-        FlinkLogicalTableFunctionScan tableFunctionScan;
-        if (right instanceof FlinkLogicalTableFunctionScan) {
-            tableFunctionScan = (FlinkLogicalTableFunctionScan) right;
-        } else if (right instanceof FlinkLogicalCalc) {
-            tableFunctionScan = 
StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc) right);
-        } else {
-            return false;
-        }
-        RexNode rexNode = tableFunctionScan.getCall();
-        if (rexNode instanceof RexCall) {
-            return PythonUtil.isPythonCall(rexNode, null)
-                            && PythonUtil.containsNonPythonCall(rexNode)
-                    || PythonUtil.isNonPythonCall(rexNode)
-                            && PythonUtil.containsPythonCall(rexNode, null)
-                    || (PythonUtil.isPythonCall(rexNode, null)
-                            && RexUtil.containsFieldAccess(rexNode));
-        }
-        return false;
-    }
-
-    private List<String> createNewFieldNames(
-            RelDataType rowType,
-            RexBuilder rexBuilder,
-            int primitiveFieldCount,
-            ArrayBuffer<RexNode> extractedRexNodes,
-            List<RexNode> calcProjects) {
-        for (int i = 0; i < primitiveFieldCount; i++) {
-            calcProjects.add(RexInputRef.of(i, rowType));
-        }
-        // change RexCorrelVariable to RexInputRef.
-        RexDefaultVisitor<RexNode> visitor =
-                new RexDefaultVisitor<RexNode>() {
-                    @Override
-                    public RexNode visitFieldAccess(RexFieldAccess 
fieldAccess) {
-                        RexNode expr = fieldAccess.getReferenceExpr();
-                        if (expr instanceof RexCorrelVariable) {
-                            RelDataTypeField field = fieldAccess.getField();
-                            return new RexInputRef(field.getIndex(), 
field.getType());
-                        } else {
-                            return rexBuilder.makeFieldAccess(
-                                    expr.accept(this), 
fieldAccess.getField().getIndex());
-                        }
-                    }
-
-                    @Override
-                    public RexNode visitNode(RexNode rexNode) {
-                        return rexNode;
-                    }
-                };
-        // add the fields of the extracted rex calls.
-        Iterator<RexNode> iterator = extractedRexNodes.iterator();
-        while (iterator.hasNext()) {
-            RexNode rexNode = iterator.next();
-            if (rexNode instanceof RexCall) {
-                RexCall rexCall = (RexCall) rexNode;
-                List<RexNode> newProjects =
-                        rexCall.getOperands().stream()
-                                .map(x -> x.accept(visitor))
-                                .collect(Collectors.toList());
-                RexCall newRexCall = rexCall.clone(rexCall.getType(), 
newProjects);
-                calcProjects.add(newRexCall);
-            } else {
-                calcProjects.add(rexNode);
-            }
-        }
-
-        List<String> nameList = new LinkedList<>();
-        for (int i = 0; i < primitiveFieldCount; i++) {
-            nameList.add(rowType.getFieldNames().get(i));
-        }
-        Iterator<Object> indicesIterator = 
extractedRexNodes.indices().iterator();
-        while (indicesIterator.hasNext()) {
-            nameList.add("f" + indicesIterator.next());
-        }
-        return SqlValidatorUtil.uniquify(
-                nameList, 
rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive());
-    }
-
-    private FlinkLogicalCalc createNewLeftCalc(
-            RelNode left,
-            RexBuilder rexBuilder,
-            ArrayBuffer<RexNode> extractedRexNodes,
-            FlinkLogicalCorrelate correlate) {
-        // add the fields of the primitive left input.
-        List<RexNode> leftCalcProjects = new LinkedList<>();
-        RelDataType leftRowType = left.getRowType();
-        List<String> leftCalcCalcFieldNames =
-                createNewFieldNames(
-                        leftRowType,
-                        rexBuilder,
-                        leftRowType.getFieldCount(),
-                        extractedRexNodes,
-                        leftCalcProjects);
-
-        // create a new calc
-        return new FlinkLogicalCalc(
-                correlate.getCluster(),
-                correlate.getTraitSet(),
-                left,
-                RexProgram.create(
-                        leftRowType, leftCalcProjects, null, 
leftCalcCalcFieldNames, rexBuilder));
-    }
-
-    private FlinkLogicalCalc createTopCalc(
-            int primitiveLeftFieldCount,
-            RexBuilder rexBuilder,
-            ArrayBuffer<RexNode> extractedRexNodes,
-            RelDataType calcRowType,
-            FlinkLogicalCorrelate newCorrelate) {
-        RexProgram rexProgram =
-                new RexProgramBuilder(newCorrelate.getRowType(), 
rexBuilder).getProgram();
-        int offset = extractedRexNodes.size() + primitiveLeftFieldCount;
-
-        // extract correlate output RexNode.
-        List<RexNode> newTopCalcProjects =
-                rexProgram.getExprList().stream()
-                        .filter(x -> x instanceof RexInputRef)
-                        .filter(
-                                x -> {
-                                    int index = ((RexInputRef) x).getIndex();
-                                    return index < primitiveLeftFieldCount || 
index >= offset;
-                                })
-                        .collect(Collectors.toList());
-
-        return new FlinkLogicalCalc(
-                newCorrelate.getCluster(),
-                newCorrelate.getTraitSet(),
-                newCorrelate,
-                RexProgram.create(
-                        newCorrelate.getRowType(),
-                        newTopCalcProjects,
-                        null,
-                        calcRowType,
-                        rexBuilder));
-    }
-
-    private ScalarFunctionSplitter createScalarFunctionSplitter(
-            RexProgram program,
-            RexBuilder rexBuilder,
-            int primitiveLeftFieldCount,
-            ArrayBuffer<RexNode> extractedRexNodes,
-            RexNode tableFunctionNode) {
-        return new ScalarFunctionSplitter(
-                program,
-                rexBuilder,
-                primitiveLeftFieldCount,
-                extractedRexNodes,
-                node -> {
-                    if (PythonUtil.isNonPythonCall(tableFunctionNode)) {
-                        // splits the RexCalls which contain Python functions 
into separate node
-                        return PythonUtil.isPythonCall(node, null);
-                    } else if (PythonUtil.containsNonPythonCall(node)) {
-                        // splits the RexCalls which contain non-Python 
functions into separate node
-                        return PythonUtil.isNonPythonCall(node);
-                    } else {
-                        // splits the RexFieldAccesses which contain 
non-Python functions into
-                        // separate node
-                        return node instanceof RexFieldAccess;
-                    }
-                },
-                new PythonRemoteCalcCallFinder());
-    }
-
-    @Override
-    public void onMatch(RelOptRuleCall call) {
-        FlinkLogicalCorrelate correlate = call.rel(0);
-        RexBuilder rexBuilder = call.builder().getRexBuilder();
-        RelNode left = ((HepRelVertex) correlate.getLeft()).getCurrentRel();
-        RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel();
-        int primitiveLeftFieldCount = left.getRowType().getFieldCount();
-        ArrayBuffer<RexNode> extractedRexNodes = new ArrayBuffer<>();
-
-        RelNode rightNewInput;
-        if (right instanceof FlinkLogicalTableFunctionScan) {
-            FlinkLogicalTableFunctionScan scan = 
(FlinkLogicalTableFunctionScan) right;
-            rightNewInput =
-                    createNewScan(
-                            scan,
-                            createScalarFunctionSplitter(
-                                    null,
-                                    rexBuilder,
-                                    primitiveLeftFieldCount,
-                                    extractedRexNodes,
-                                    scan.getCall()));
-        } else {
-            FlinkLogicalCalc calc = (FlinkLogicalCalc) right;
-            FlinkLogicalTableFunctionScan scan = 
StreamPhysicalCorrelateRule.getTableScan(calc);
-            FlinkLogicalCalc mergedCalc = 
StreamPhysicalCorrelateRule.getMergedCalc(calc);
-            FlinkLogicalTableFunctionScan newScan =
-                    createNewScan(
-                            scan,
-                            createScalarFunctionSplitter(
-                                    null,
-                                    rexBuilder,
-                                    primitiveLeftFieldCount,
-                                    extractedRexNodes,
-                                    scan.getCall()));
-            rightNewInput =
-                    mergedCalc.copy(mergedCalc.getTraitSet(), newScan, 
mergedCalc.getProgram());
-        }
-
-        FlinkLogicalCorrelate newCorrelate;
-        if (extractedRexNodes.size() > 0) {
-            FlinkLogicalCalc leftCalc =
-                    createNewLeftCalc(left, rexBuilder, extractedRexNodes, 
correlate);
-
-            newCorrelate =
-                    new FlinkLogicalCorrelate(
-                            correlate.getCluster(),
-                            correlate.getTraitSet(),
-                            leftCalc,
-                            rightNewInput,
-                            correlate.getCorrelationId(),
-                            correlate.getRequiredColumns(),
-                            correlate.getJoinType());
-        } else {
-            newCorrelate =
-                    new FlinkLogicalCorrelate(
-                            correlate.getCluster(),
-                            correlate.getTraitSet(),
-                            left,
-                            rightNewInput,
-                            correlate.getCorrelationId(),
-                            correlate.getRequiredColumns(),
-                            correlate.getJoinType());
-        }
-
-        FlinkLogicalCalc newTopCalc =
-                createTopCalc(
-                        primitiveLeftFieldCount,
-                        rexBuilder,
-                        extractedRexNodes,
-                        correlate.getRowType(),
-                        newCorrelate);
+public class PythonCorrelateSplitRule {
 
-        call.transformTo(newTopCalc);
-    }
+    public static final RelOptRule INSTANCE =
+            RemoteCorrelateSplitRule.Config.createDefault(new 
PythonRemoteCalcCallFinder())
+                    .toRule();
 }
diff --git 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java
similarity index 79%
copy from 
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
copy to 
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java
index 986f0fc538c..d12d72f7755 100644
--- 
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
+++ 
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java
@@ -22,11 +22,11 @@ import 
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
 import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
 import 
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
 import 
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
-import org.apache.flink.table.planner.plan.utils.PythonUtil;
 import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
 
 import org.apache.calcite.plan.RelOptRule;
 import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
 import org.apache.calcite.plan.hep.HepRelVertex;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.type.RelDataType;
@@ -41,6 +41,7 @@ import org.apache.calcite.rex.RexProgram;
 import org.apache.calcite.rex.RexProgramBuilder;
 import org.apache.calcite.rex.RexUtil;
 import org.apache.calcite.sql.validate.SqlValidatorUtil;
+import org.immutables.value.Value;
 
 import java.util.LinkedList;
 import java.util.List;
@@ -50,22 +51,24 @@ import scala.collection.Iterator;
 import scala.collection.mutable.ArrayBuffer;
 
 /**
- * Rule will split the Python {@link FlinkLogicalTableFunctionScan} with Java 
calls or the Java
- * {@link FlinkLogicalTableFunctionScan} with Python calls into a {@link 
FlinkLogicalCalc} which
+ * Rule will split the Remote {@link FlinkLogicalTableFunctionScan} with Java 
calls or the Java
+ * {@link FlinkLogicalTableFunctionScan} with Remote calls into a {@link 
FlinkLogicalCalc} which
  * will be the left input of the new {@link FlinkLogicalCorrelate} and a new 
{@link
  * FlinkLogicalTableFunctionScan}.
  */
-public class PythonCorrelateSplitRule extends RelOptRule {
-    public static final PythonCorrelateSplitRule INSTANCE = new 
PythonCorrelateSplitRule();
+@Value.Enclosing
+public class RemoteCorrelateSplitRule extends 
RelRule<RemoteCorrelateSplitRule.Config> {
+    private final RemoteCalcCallFinder callFinder;
 
-    private PythonCorrelateSplitRule() {
-        super(operand(FlinkLogicalCorrelate.class, any()), 
"PythonCorrelateSplitRule");
+    RemoteCorrelateSplitRule(Config config, RemoteCalcCallFinder callFinder) {
+        super(config);
+        this.callFinder = callFinder;
     }
 
     private FlinkLogicalTableFunctionScan createNewScan(
             FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter 
splitter) {
         RexCall rightRexCall = (RexCall) scan.getCall();
-        // extract Java funcs from Python TableFunction or Python funcs from 
Java TableFunction.
+        // extract Java funcs from Remote TableFunction or Remote funcs from 
Java TableFunction.
         List<RexNode> rightCalcProjects =
                 rightRexCall.getOperands().stream()
                         .map(x -> x.accept(splitter))
@@ -96,12 +99,10 @@ public class PythonCorrelateSplitRule extends RelOptRule {
         }
         RexNode rexNode = tableFunctionScan.getCall();
         if (rexNode instanceof RexCall) {
-            return PythonUtil.isPythonCall(rexNode, null)
-                            && PythonUtil.containsNonPythonCall(rexNode)
-                    || PythonUtil.isNonPythonCall(rexNode)
-                            && PythonUtil.containsPythonCall(rexNode, null)
-                    || (PythonUtil.isPythonCall(rexNode, null)
-                            && RexUtil.containsFieldAccess(rexNode));
+            return (callFinder.isRemoteCall(rexNode) && 
callFinder.containsNonRemoteCall(rexNode))
+                    || (callFinder.isNonRemoteCall(rexNode)
+                            && callFinder.containsRemoteCall(rexNode))
+                    || (callFinder.isRemoteCall(rexNode) && 
RexUtil.containsFieldAccess(rexNode));
         }
         return false;
     }
@@ -130,6 +131,15 @@ public class PythonCorrelateSplitRule extends RelOptRule {
                         }
                     }
 
+                    @Override
+                    public RexNode visitCall(RexCall call) {
+                        List<RexNode> newProjects =
+                                call.getOperands().stream()
+                                        .map(x -> x.accept(this))
+                                        .collect(Collectors.toList());
+                        return rexBuilder.makeCall(call.getOperator(), 
newProjects);
+                    }
+
                     @Override
                     public RexNode visitNode(RexNode rexNode) {
                         return rexNode;
@@ -223,30 +233,30 @@ public class PythonCorrelateSplitRule extends RelOptRule {
     }
 
     private ScalarFunctionSplitter createScalarFunctionSplitter(
-            RexProgram program,
             RexBuilder rexBuilder,
             int primitiveLeftFieldCount,
             ArrayBuffer<RexNode> extractedRexNodes,
             RexNode tableFunctionNode) {
         return new ScalarFunctionSplitter(
-                program,
+                // The scan should not contain any local references to 
resolve, so null is passed.
+                null,
                 rexBuilder,
                 primitiveLeftFieldCount,
                 extractedRexNodes,
                 node -> {
-                    if (PythonUtil.isNonPythonCall(tableFunctionNode)) {
-                        // splits the RexCalls which contain Python functions 
into separate node
-                        return PythonUtil.isPythonCall(node, null);
-                    } else if (PythonUtil.containsNonPythonCall(node)) {
-                        // splits the RexCalls which contain non-Python 
functions into separate node
-                        return PythonUtil.isNonPythonCall(node);
+                    if (callFinder.isNonRemoteCall(tableFunctionNode)) {
+                        // splits the RexCalls which contain Remote functions 
into separate node
+                        return callFinder.isRemoteCall(node);
+                    } else if (callFinder.containsNonRemoteCall(node)) {
+                        // splits the RexCalls which contain non-Remote 
functions into separate node
+                        return callFinder.isNonRemoteCall(node);
                     } else {
-                        // splits the RexFieldAccesses which contain 
non-Python functions into
+                        // splits the RexFieldAccesses which contain 
non-Remote functions into
                         // separate node
                         return node instanceof RexFieldAccess;
                     }
                 },
-                new PythonRemoteCalcCallFinder());
+                callFinder);
     }
 
     @Override
@@ -265,7 +275,6 @@ public class PythonCorrelateSplitRule extends RelOptRule {
                     createNewScan(
                             scan,
                             createScalarFunctionSplitter(
-                                    null,
                                     rexBuilder,
                                     primitiveLeftFieldCount,
                                     extractedRexNodes,
@@ -278,7 +287,6 @@ public class PythonCorrelateSplitRule extends RelOptRule {
                     createNewScan(
                             scan,
                             createScalarFunctionSplitter(
-                                    null,
                                     rexBuilder,
                                     primitiveLeftFieldCount,
                                     extractedRexNodes,
@@ -288,7 +296,7 @@ public class PythonCorrelateSplitRule extends RelOptRule {
         }
 
         FlinkLogicalCorrelate newCorrelate;
-        if (extractedRexNodes.size() > 0) {
+        if (!extractedRexNodes.isEmpty()) {
             FlinkLogicalCalc leftCalc =
                     createNewLeftCalc(left, rexBuilder, extractedRexNodes, 
correlate);
 
@@ -323,4 +331,41 @@ public class PythonCorrelateSplitRule extends RelOptRule {
 
         call.transformTo(newTopCalc);
     }
+
+    // Consider the rules to be equal if they are the same class and their 
call finders are the same
+    // class.
+    @Override
+    public boolean equals(Object object) {
+        if (object == null || 
!object.getClass().equals(RemoteCorrelateSplitRule.class)) {
+            return false;
+        }
+        RemoteCorrelateSplitRule rule = (RemoteCorrelateSplitRule) object;
+        return callFinder.equals(rule.callFinder);
+    }
+
+    @Override
+    public int hashCode() {
+        return callFinder.hashCode();
+    }
+
+    @Value.Immutable(singleton = false)
+    public interface Config extends RelRule.Config {
+
+        public abstract RemoteCalcCallFinder callFinder();
+
+        static RemoteCorrelateSplitRule.Config 
createDefault(RemoteCalcCallFinder callFinder) {
+            return ImmutableRemoteCorrelateSplitRule.Config.builder()
+                    .callFinder(callFinder)
+                    .build()
+                    .withOperandSupplier(b0 -> 
b0.operand(FlinkLogicalCorrelate.class).anyInputs());
+        }
+
+        @Override
+        default RelOptRule toRule() {
+            return new RemoteCorrelateSplitRule(
+                    this.withDescription("RemoteCorrelateSplitRule-" + 
callFinder().getName())
+                            .as(Config.class),
+                    callFinder());
+        }
+    }
 }
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index d029cf6e44c..74219150ed0 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -417,7 +417,9 @@ object FlinkStreamRuleSets {
     // Avoid async calls which call async calls.
     AsyncCalcSplitRule.NESTED_SPLIT,
     // Avoid having async calls in multiple projections in a single calc.
-    AsyncCalcSplitRule.ONE_PER_CALC_SPLIT
+    AsyncCalcSplitRule.ONE_PER_CALC_SPLIT,
+    // Split async calls from correlates
+    AsyncCorrelateSplitRule.INSTANCE
   )
 
   /** RuleSet to do physical optimize for stream */
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
index 1756b9eb5e8..e5cf07025c3 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
@@ -81,6 +81,16 @@ class PythonRemoteCalcCallFinder extends 
RemoteCalcCallFinder {
   override def isNonRemoteCall(node: RexNode): Boolean = {
     PythonUtil.isNonPythonCall(node)
   }
+
+  override def equals(obj: Any): Boolean = {
+    obj != null && obj.isInstanceOf[PythonRemoteCalcCallFinder]
+  }
+
+  override def hashCode(): Int = {
+    this.getClass.hashCode()
+  }
+
+  override def getName: String = "Python"
 }
 
 object PythonCalcSplitRule {
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
index 5209ab1b14b..68c0e281b52 100644
--- 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
@@ -34,4 +34,7 @@ public interface RemoteCalcCallFinder {
 
     // If the node contains directly a non-remote call.
     boolean isNonRemoteCall(RexNode node);
+
+    // A name that can be appended onto the rule
+    String getName();
 }
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java
index e97abdf1cfc..1098c0585b2 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.table.annotation.DataTypeHint;
 import org.apache.flink.table.api.TableConfig;
 import org.apache.flink.table.api.TableEnvironment;
 import org.apache.flink.table.functions.AsyncScalarFunction;
+import org.apache.flink.table.functions.TableFunction;
 import 
org.apache.flink.table.planner.plan.optimize.program.BatchOptimizeContext;
 import 
org.apache.flink.table.planner.plan.optimize.program.FlinkChainedProgram;
 import 
org.apache.flink.table.planner.plan.optimize.program.FlinkHepRuleSetProgramBuilder;
@@ -82,6 +83,7 @@ public class AsyncCalcSplitRuleTest extends TableTestBase {
         util.addTemporarySystemFunction("func4", new Func4());
         util.addTemporarySystemFunction("func5", new Func5());
         util.addTemporarySystemFunction("func6", new Func6());
+        util.addTemporarySystemFunction("tableFunc", new 
RandomTableFunction());
     }
 
     @Test
@@ -370,4 +372,13 @@ public class AsyncCalcSplitRuleTest extends TableTestBase {
             future.complete(param + param2);
         }
     }
+
+    /** Test function. */
+    public static class RandomTableFunction extends TableFunction<String> {
+
+        public void eval(Integer i) {
+            collect("blah " + i);
+            collect("foo " + i);
+        }
+    }
 }
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java
new file mode 100644
index 00000000000..77164b9185a
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableEnvironment;
+import 
org.apache.flink.table.planner.plan.optimize.program.FlinkChainedProgram;
+import 
org.apache.flink.table.planner.plan.optimize.program.FlinkHepRuleSetProgramBuilder;
+import 
org.apache.flink.table.planner.plan.optimize.program.HEP_RULES_EXECUTION_TYPE;
+import 
org.apache.flink.table.planner.plan.optimize.program.StreamOptimizeContext;
+import org.apache.flink.table.planner.plan.rules.FlinkStreamRuleSets;
+import 
org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRuleTest.Func1;
+import 
org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRuleTest.RandomTableFunction;
+import org.apache.flink.table.planner.utils.TableTestBase;
+import org.apache.flink.table.planner.utils.TableTestUtil;
+
+import org.apache.calcite.plan.hep.HepMatchOrder;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/** Test for {@link AsyncCorrelateSplitRule}. */
+public class AsyncCorrelateSplitRuleTest extends TableTestBase {
+
+    private final TableTestUtil util = 
streamTestUtil(TableConfig.getDefault());
+
+    @BeforeEach
+    public void setup() {
+        FlinkChainedProgram<StreamOptimizeContext> programs = new 
FlinkChainedProgram<>();
+        programs.addLast(
+                "logical_rewrite",
+                
FlinkHepRuleSetProgramBuilder.<StreamOptimizeContext>newBuilder()
+                        
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE())
+                        .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+                        .add(FlinkStreamRuleSets.LOGICAL_REWRITE())
+                        .build());
+
+        TableEnvironment tEnv = util.getTableEnv();
+        tEnv.executeSql(
+                "CREATE TABLE MyTable (\n"
+                        + "  a int,\n"
+                        + "  b bigint,\n"
+                        + "  c string,\n"
+                        + "  d ARRAY<INT NOT NULL>\n"
+                        + ") WITH (\n"
+                        + "  'connector' = 'test-simple-table-source'\n"
+                        + ") ;");
+
+        util.addTemporarySystemFunction("func1", new Func1());
+        util.addTemporarySystemFunction("tableFunc", new 
RandomTableFunction());
+    }
+
+    @Test
+    public void testCorrelateImmediate() {
+        String sqlQuery = "select * FROM MyTable, LATERAL 
TABLE(tableFunc(func1(a)))";
+        util.verifyRelPlan(sqlQuery);
+    }
+
+    @Test
+    public void testCorrelateIndirect() {
+        String sqlQuery = "select * FROM MyTable, LATERAL 
TABLE(tableFunc(ABS(func1(a))))";
+        util.verifyRelPlan(sqlQuery);
+    }
+
+    @Test
+    public void testCorrelateIndirectOtherWay() {
+        String sqlQuery = "select * FROM MyTable, LATERAL 
TABLE(tableFunc(func1(ABS(a))))";
+        util.verifyRelPlan(sqlQuery);
+    }
+}
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
index ec0eefe33ae..923735c0eb2 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
@@ -28,6 +28,7 @@ import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
 import org.apache.flink.table.api.config.ExecutionConfigOptions;
 import org.apache.flink.table.functions.AsyncScalarFunction;
 import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.TableFunction;
 import org.apache.flink.table.planner.runtime.utils.StreamingTestBase;
 import org.apache.flink.types.Row;
 import org.apache.flink.types.RowKind;
@@ -240,6 +241,22 @@ public class AsyncCalcITCase extends StreamingTestBase {
         assertThat(results).containsSequence(expectedRows);
     }
 
+    @Test
+    public void testTableFuncWithAsyncCalc() {
+        Table t1 = tEnv.fromValues(1, 2).as("f1");
+        tEnv.createTemporaryView("t1", t1);
+        tEnv.createTemporarySystemFunction("func", new RandomTableFunction());
+        tEnv.createTemporarySystemFunction("addTen", new AsyncFuncAdd10());
+        final List<Row> results = executeSql("select * FROM t1, LATERAL 
TABLE(func(addTen(f1)))");
+        final List<Row> expectedRows =
+                Arrays.asList(
+                        Row.of(1, "blah 11"),
+                        Row.of(1, "foo 11"),
+                        Row.of(2, "blah 12"),
+                        Row.of(2, "foo 12"));
+        assertThat(results).containsSequence(expectedRows);
+    }
+
     @Test
     public void testMultiArgumentAsyncWithAdditionalProjection() {
         // This was the cause of a bug previously where the reference to the 
sync projection was
@@ -423,4 +440,13 @@ public class AsyncCalcITCase extends StreamingTestBase {
             executor.schedule(() -> future.complete(param1 + param2), 10, 
TimeUnit.MILLISECONDS);
         }
     }
+
+    /** A table function. */
+    public static class RandomTableFunction extends TableFunction<String> {
+
+        public void eval(Integer i) {
+            collect("blah " + i);
+            collect("foo " + i);
+        }
+    }
 }
diff --git a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml
new file mode 100644
index 00000000000..00b283f7b1c
--- /dev/null
+++ b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml
@@ -0,0 +1,84 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+       <TestCase name="testCorrelateImmediate">
+               <Resource name="sql">
+                       <![CDATA[select * FROM MyTable, LATERAL 
TABLE(tableFunc(func1(a)))]]>
+               </Resource>
+               <Resource name="ast">
+                       <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], EXPR$0=[$4])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{0}])
+   :- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+   +- LogicalTableFunctionScan(invocation=[tableFunc(func1($cor0.a))], 
rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
+]]>
+               </Resource>
+               <Resource name="optimized rel plan">
+                       <![CDATA[
+Calc(select=[a, b, c, d, EXPR$0])
++- Correlate(invocation=[tableFunc($4)], correlate=[table(tableFunc(f0))], 
select=[a,b,c,d,f0,EXPR$0], rowType=[RecordType(INTEGER a, BIGINT b, 
VARCHAR(2147483647) c, INTEGER ARRAY d, INTEGER f0, VARCHAR(2147483647) 
EXPR$0)], joinType=[INNER])
+   +- AsyncCalc(select=[a, b, c, d, func1(a) AS f0])
+      +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], 
fields=[a, b, c, d])
+]]>
+               </Resource>
+       </TestCase>
+
+       <TestCase name="testCorrelateIndirect">
+               <Resource name="sql">
+                       <![CDATA[select * FROM MyTable, LATERAL 
TABLE(tableFunc(ABS(func1(a))))]]>
+               </Resource>
+               <Resource name="ast">
+                       <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], EXPR$0=[$4])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{0}])
+   :- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+   +- LogicalTableFunctionScan(invocation=[tableFunc(ABS(func1($cor0.a)))], 
rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
+]]>
+               </Resource>
+               <Resource name="optimized rel plan">
+                       <![CDATA[
+Calc(select=[a, b, c, d, EXPR$0])
++- Correlate(invocation=[tableFunc(ABS($4))], 
correlate=[table(tableFunc(ABS(f0)))], select=[a,b,c,d,f0,EXPR$0], 
rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, INTEGER ARRAY 
d, INTEGER f0, VARCHAR(2147483647) EXPR$0)], joinType=[INNER])
+   +- AsyncCalc(select=[a, b, c, d, func1(a) AS f0])
+      +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], 
fields=[a, b, c, d])
+]]>
+               </Resource>
+       </TestCase>
+
+       <TestCase name="testCorrelateIndirectOtherWay">
+               <Resource name="sql">
+                       <![CDATA[select * FROM MyTable, LATERAL 
TABLE(tableFunc(func1(ABS(a))))]]>
+               </Resource>
+               <Resource name="ast">
+                       <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], EXPR$0=[$4])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner], 
requiredColumns=[{0}])
+   :- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+   +- LogicalTableFunctionScan(invocation=[tableFunc(func1(ABS($cor0.a)))], 
rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
+]]>
+               </Resource>
+               <Resource name="optimized rel plan">
+                       <![CDATA[
+Calc(select=[a, b, c, d, EXPR$0])
++- Correlate(invocation=[tableFunc($4)], correlate=[table(tableFunc(f0))], 
select=[a,b,c,d,f0,EXPR$0], rowType=[RecordType(INTEGER a, BIGINT b, 
VARCHAR(2147483647) c, INTEGER ARRAY d, INTEGER f0, VARCHAR(2147483647) 
EXPR$0)], joinType=[INNER])
+   +- AsyncCalc(select=[a, b, c, d, func1(ABS(a)) AS f0])
+      +- TableSourceScan(table=[[default_catalog, default_database, MyTable]], 
fields=[a, b, c, d])
+]]>
+               </Resource>
+       </TestCase>
+</Root>


Reply via email to