This is an automated email from the ASF dual-hosted git repository.
dwysakowicz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 53616d61e22 [FLINK-37725] Makes Async calcs share correlate split rule
with Python (#26505)
53616d61e22 is described below
commit 53616d61e22491931bec10573eeee316ca762732
Author: Alan Sheinberg <[email protected]>
AuthorDate: Thu May 15 05:03:48 2025 -0700
[FLINK-37725] Makes Async calcs share correlate split rule with Python
(#26505)
---
.../plan/rules/logical/AsyncCalcSplitRule.java | 15 ++
.../rules/logical/AsyncCorrelateSplitRule.java | 40 +++
.../rules/logical/PythonCorrelateSplitRule.java | 296 +--------------------
...plitRule.java => RemoteCorrelateSplitRule.java} | 99 +++++--
.../planner/plan/rules/FlinkStreamRuleSets.scala | 4 +-
.../plan/rules/logical/PythonCalcSplitRule.scala | 10 +
.../plan/rules/logical/RemoteCalcCallFinder.java | 3 +
.../plan/rules/logical/AsyncCalcSplitRuleTest.java | 11 +
.../rules/logical/AsyncCorrelateSplitRuleTest.java | 85 ++++++
.../runtime/stream/table/AsyncCalcITCase.java | 26 ++
.../rules/logical/AsyncCorrelateSplitRuleTest.xml | 84 ++++++
11 files changed, 353 insertions(+), 320 deletions(-)
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
index aa0b8b85528..cc7c72a909e 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRule.java
@@ -90,6 +90,21 @@ public class AsyncCalcSplitRule {
public boolean isNonRemoteCall(RexNode node) {
return AsyncUtil.isNonAsyncCall(node);
}
+
+ @Override
+ public String getName() {
+ return "Async";
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ return obj != null && this.getClass() == obj.getClass();
+ }
+
+ @Override
+ public int hashCode() {
+ return this.getClass().hashCode();
+ }
}
private static boolean hasNestedCalls(List<RexNode> projects) {
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java
new file mode 100644
index 00000000000..0094094c0f1
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRule.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
+import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
+import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
+import
org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRule.AsyncRemoteCalcCallFinder;
+
+import org.apache.calcite.plan.RelOptRule;
+
+/**
+ * Rule will split the Async {@link FlinkLogicalTableFunctionScan} with Java
calls or the Java
+ * {@link FlinkLogicalTableFunctionScan} with Async calls into a {@link
FlinkLogicalCalc} which will
+ * be the left input of the new {@link FlinkLogicalCorrelate} and a new {@link
+ * FlinkLogicalTableFunctionScan}.
+ */
+public class AsyncCorrelateSplitRule {
+
+ private static final RemoteCalcCallFinder ASYNC_CALL_FINDER = new
AsyncRemoteCalcCallFinder();
+
+ public static final RelOptRule INSTANCE =
+
RemoteCorrelateSplitRule.Config.createDefault(ASYNC_CALL_FINDER).toRule();
+}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
index 986f0fc538c..5a60649e8ae 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
@@ -21,33 +21,8 @@ package org.apache.flink.table.planner.plan.rules.logical;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
-import
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
-import org.apache.flink.table.planner.plan.utils.PythonUtil;
-import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
import org.apache.calcite.plan.RelOptRule;
-import org.apache.calcite.plan.RelOptRuleCall;
-import org.apache.calcite.plan.hep.HepRelVertex;
-import org.apache.calcite.rel.RelNode;
-import org.apache.calcite.rel.type.RelDataType;
-import org.apache.calcite.rel.type.RelDataTypeField;
-import org.apache.calcite.rex.RexBuilder;
-import org.apache.calcite.rex.RexCall;
-import org.apache.calcite.rex.RexCorrelVariable;
-import org.apache.calcite.rex.RexFieldAccess;
-import org.apache.calcite.rex.RexInputRef;
-import org.apache.calcite.rex.RexNode;
-import org.apache.calcite.rex.RexProgram;
-import org.apache.calcite.rex.RexProgramBuilder;
-import org.apache.calcite.rex.RexUtil;
-import org.apache.calcite.sql.validate.SqlValidatorUtil;
-
-import java.util.LinkedList;
-import java.util.List;
-import java.util.stream.Collectors;
-
-import scala.collection.Iterator;
-import scala.collection.mutable.ArrayBuffer;
/**
* Rule will split the Python {@link FlinkLogicalTableFunctionScan} with Java
calls or the Java
@@ -55,272 +30,9 @@ import scala.collection.mutable.ArrayBuffer;
* will be the left input of the new {@link FlinkLogicalCorrelate} and a new
{@link
* FlinkLogicalTableFunctionScan}.
*/
-public class PythonCorrelateSplitRule extends RelOptRule {
- public static final PythonCorrelateSplitRule INSTANCE = new
PythonCorrelateSplitRule();
-
- private PythonCorrelateSplitRule() {
- super(operand(FlinkLogicalCorrelate.class, any()),
"PythonCorrelateSplitRule");
- }
-
- private FlinkLogicalTableFunctionScan createNewScan(
- FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter
splitter) {
- RexCall rightRexCall = (RexCall) scan.getCall();
- // extract Java funcs from Python TableFunction or Python funcs from
Java TableFunction.
- List<RexNode> rightCalcProjects =
- rightRexCall.getOperands().stream()
- .map(x -> x.accept(splitter))
- .collect(Collectors.toList());
-
- RexCall newRightRexCall = rightRexCall.clone(rightRexCall.getType(),
rightCalcProjects);
- return new FlinkLogicalTableFunctionScan(
- scan.getCluster(),
- scan.getTraitSet(),
- scan.getInputs(),
- newRightRexCall,
- scan.getElementType(),
- scan.getRowType(),
- scan.getColumnMappings());
- }
-
- @Override
- public boolean matches(RelOptRuleCall call) {
- FlinkLogicalCorrelate correlate = call.rel(0);
- RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel();
- FlinkLogicalTableFunctionScan tableFunctionScan;
- if (right instanceof FlinkLogicalTableFunctionScan) {
- tableFunctionScan = (FlinkLogicalTableFunctionScan) right;
- } else if (right instanceof FlinkLogicalCalc) {
- tableFunctionScan =
StreamPhysicalCorrelateRule.getTableScan((FlinkLogicalCalc) right);
- } else {
- return false;
- }
- RexNode rexNode = tableFunctionScan.getCall();
- if (rexNode instanceof RexCall) {
- return PythonUtil.isPythonCall(rexNode, null)
- && PythonUtil.containsNonPythonCall(rexNode)
- || PythonUtil.isNonPythonCall(rexNode)
- && PythonUtil.containsPythonCall(rexNode, null)
- || (PythonUtil.isPythonCall(rexNode, null)
- && RexUtil.containsFieldAccess(rexNode));
- }
- return false;
- }
-
- private List<String> createNewFieldNames(
- RelDataType rowType,
- RexBuilder rexBuilder,
- int primitiveFieldCount,
- ArrayBuffer<RexNode> extractedRexNodes,
- List<RexNode> calcProjects) {
- for (int i = 0; i < primitiveFieldCount; i++) {
- calcProjects.add(RexInputRef.of(i, rowType));
- }
- // change RexCorrelVariable to RexInputRef.
- RexDefaultVisitor<RexNode> visitor =
- new RexDefaultVisitor<RexNode>() {
- @Override
- public RexNode visitFieldAccess(RexFieldAccess
fieldAccess) {
- RexNode expr = fieldAccess.getReferenceExpr();
- if (expr instanceof RexCorrelVariable) {
- RelDataTypeField field = fieldAccess.getField();
- return new RexInputRef(field.getIndex(),
field.getType());
- } else {
- return rexBuilder.makeFieldAccess(
- expr.accept(this),
fieldAccess.getField().getIndex());
- }
- }
-
- @Override
- public RexNode visitNode(RexNode rexNode) {
- return rexNode;
- }
- };
- // add the fields of the extracted rex calls.
- Iterator<RexNode> iterator = extractedRexNodes.iterator();
- while (iterator.hasNext()) {
- RexNode rexNode = iterator.next();
- if (rexNode instanceof RexCall) {
- RexCall rexCall = (RexCall) rexNode;
- List<RexNode> newProjects =
- rexCall.getOperands().stream()
- .map(x -> x.accept(visitor))
- .collect(Collectors.toList());
- RexCall newRexCall = rexCall.clone(rexCall.getType(),
newProjects);
- calcProjects.add(newRexCall);
- } else {
- calcProjects.add(rexNode);
- }
- }
-
- List<String> nameList = new LinkedList<>();
- for (int i = 0; i < primitiveFieldCount; i++) {
- nameList.add(rowType.getFieldNames().get(i));
- }
- Iterator<Object> indicesIterator =
extractedRexNodes.indices().iterator();
- while (indicesIterator.hasNext()) {
- nameList.add("f" + indicesIterator.next());
- }
- return SqlValidatorUtil.uniquify(
- nameList,
rexBuilder.getTypeFactory().getTypeSystem().isSchemaCaseSensitive());
- }
-
- private FlinkLogicalCalc createNewLeftCalc(
- RelNode left,
- RexBuilder rexBuilder,
- ArrayBuffer<RexNode> extractedRexNodes,
- FlinkLogicalCorrelate correlate) {
- // add the fields of the primitive left input.
- List<RexNode> leftCalcProjects = new LinkedList<>();
- RelDataType leftRowType = left.getRowType();
- List<String> leftCalcCalcFieldNames =
- createNewFieldNames(
- leftRowType,
- rexBuilder,
- leftRowType.getFieldCount(),
- extractedRexNodes,
- leftCalcProjects);
-
- // create a new calc
- return new FlinkLogicalCalc(
- correlate.getCluster(),
- correlate.getTraitSet(),
- left,
- RexProgram.create(
- leftRowType, leftCalcProjects, null,
leftCalcCalcFieldNames, rexBuilder));
- }
-
- private FlinkLogicalCalc createTopCalc(
- int primitiveLeftFieldCount,
- RexBuilder rexBuilder,
- ArrayBuffer<RexNode> extractedRexNodes,
- RelDataType calcRowType,
- FlinkLogicalCorrelate newCorrelate) {
- RexProgram rexProgram =
- new RexProgramBuilder(newCorrelate.getRowType(),
rexBuilder).getProgram();
- int offset = extractedRexNodes.size() + primitiveLeftFieldCount;
-
- // extract correlate output RexNode.
- List<RexNode> newTopCalcProjects =
- rexProgram.getExprList().stream()
- .filter(x -> x instanceof RexInputRef)
- .filter(
- x -> {
- int index = ((RexInputRef) x).getIndex();
- return index < primitiveLeftFieldCount ||
index >= offset;
- })
- .collect(Collectors.toList());
-
- return new FlinkLogicalCalc(
- newCorrelate.getCluster(),
- newCorrelate.getTraitSet(),
- newCorrelate,
- RexProgram.create(
- newCorrelate.getRowType(),
- newTopCalcProjects,
- null,
- calcRowType,
- rexBuilder));
- }
-
- private ScalarFunctionSplitter createScalarFunctionSplitter(
- RexProgram program,
- RexBuilder rexBuilder,
- int primitiveLeftFieldCount,
- ArrayBuffer<RexNode> extractedRexNodes,
- RexNode tableFunctionNode) {
- return new ScalarFunctionSplitter(
- program,
- rexBuilder,
- primitiveLeftFieldCount,
- extractedRexNodes,
- node -> {
- if (PythonUtil.isNonPythonCall(tableFunctionNode)) {
- // splits the RexCalls which contain Python functions
into separate node
- return PythonUtil.isPythonCall(node, null);
- } else if (PythonUtil.containsNonPythonCall(node)) {
- // splits the RexCalls which contain non-Python
functions into separate node
- return PythonUtil.isNonPythonCall(node);
- } else {
- // splits the RexFieldAccesses which contain
non-Python functions into
- // separate node
- return node instanceof RexFieldAccess;
- }
- },
- new PythonRemoteCalcCallFinder());
- }
-
- @Override
- public void onMatch(RelOptRuleCall call) {
- FlinkLogicalCorrelate correlate = call.rel(0);
- RexBuilder rexBuilder = call.builder().getRexBuilder();
- RelNode left = ((HepRelVertex) correlate.getLeft()).getCurrentRel();
- RelNode right = ((HepRelVertex) correlate.getRight()).getCurrentRel();
- int primitiveLeftFieldCount = left.getRowType().getFieldCount();
- ArrayBuffer<RexNode> extractedRexNodes = new ArrayBuffer<>();
-
- RelNode rightNewInput;
- if (right instanceof FlinkLogicalTableFunctionScan) {
- FlinkLogicalTableFunctionScan scan =
(FlinkLogicalTableFunctionScan) right;
- rightNewInput =
- createNewScan(
- scan,
- createScalarFunctionSplitter(
- null,
- rexBuilder,
- primitiveLeftFieldCount,
- extractedRexNodes,
- scan.getCall()));
- } else {
- FlinkLogicalCalc calc = (FlinkLogicalCalc) right;
- FlinkLogicalTableFunctionScan scan =
StreamPhysicalCorrelateRule.getTableScan(calc);
- FlinkLogicalCalc mergedCalc =
StreamPhysicalCorrelateRule.getMergedCalc(calc);
- FlinkLogicalTableFunctionScan newScan =
- createNewScan(
- scan,
- createScalarFunctionSplitter(
- null,
- rexBuilder,
- primitiveLeftFieldCount,
- extractedRexNodes,
- scan.getCall()));
- rightNewInput =
- mergedCalc.copy(mergedCalc.getTraitSet(), newScan,
mergedCalc.getProgram());
- }
-
- FlinkLogicalCorrelate newCorrelate;
- if (extractedRexNodes.size() > 0) {
- FlinkLogicalCalc leftCalc =
- createNewLeftCalc(left, rexBuilder, extractedRexNodes,
correlate);
-
- newCorrelate =
- new FlinkLogicalCorrelate(
- correlate.getCluster(),
- correlate.getTraitSet(),
- leftCalc,
- rightNewInput,
- correlate.getCorrelationId(),
- correlate.getRequiredColumns(),
- correlate.getJoinType());
- } else {
- newCorrelate =
- new FlinkLogicalCorrelate(
- correlate.getCluster(),
- correlate.getTraitSet(),
- left,
- rightNewInput,
- correlate.getCorrelationId(),
- correlate.getRequiredColumns(),
- correlate.getJoinType());
- }
-
- FlinkLogicalCalc newTopCalc =
- createTopCalc(
- primitiveLeftFieldCount,
- rexBuilder,
- extractedRexNodes,
- correlate.getRowType(),
- newCorrelate);
+public class PythonCorrelateSplitRule {
- call.transformTo(newTopCalc);
- }
+ public static final RelOptRule INSTANCE =
+ RemoteCorrelateSplitRule.Config.createDefault(new
PythonRemoteCalcCallFinder())
+ .toRule();
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java
similarity index 79%
copy from
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
copy to
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java
index 986f0fc538c..d12d72f7755 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/PythonCorrelateSplitRule.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/rules/logical/RemoteCorrelateSplitRule.java
@@ -22,11 +22,11 @@ import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCalc;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalCorrelate;
import
org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalTableFunctionScan;
import
org.apache.flink.table.planner.plan.rules.physical.stream.StreamPhysicalCorrelateRule;
-import org.apache.flink.table.planner.plan.utils.PythonUtil;
import org.apache.flink.table.planner.plan.utils.RexDefaultVisitor;
import org.apache.calcite.plan.RelOptRule;
import org.apache.calcite.plan.RelOptRuleCall;
+import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataType;
@@ -41,6 +41,7 @@ import org.apache.calcite.rex.RexProgram;
import org.apache.calcite.rex.RexProgramBuilder;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.validate.SqlValidatorUtil;
+import org.immutables.value.Value;
import java.util.LinkedList;
import java.util.List;
@@ -50,22 +51,24 @@ import scala.collection.Iterator;
import scala.collection.mutable.ArrayBuffer;
/**
- * Rule will split the Python {@link FlinkLogicalTableFunctionScan} with Java
calls or the Java
- * {@link FlinkLogicalTableFunctionScan} with Python calls into a {@link
FlinkLogicalCalc} which
+ * Rule will split the Remote {@link FlinkLogicalTableFunctionScan} with Java
calls or the Java
+ * {@link FlinkLogicalTableFunctionScan} with Remote calls into a {@link
FlinkLogicalCalc} which
* will be the left input of the new {@link FlinkLogicalCorrelate} and a new
{@link
* FlinkLogicalTableFunctionScan}.
*/
-public class PythonCorrelateSplitRule extends RelOptRule {
- public static final PythonCorrelateSplitRule INSTANCE = new
PythonCorrelateSplitRule();
+@Value.Enclosing
+public class RemoteCorrelateSplitRule extends
RelRule<RemoteCorrelateSplitRule.Config> {
+ private final RemoteCalcCallFinder callFinder;
- private PythonCorrelateSplitRule() {
- super(operand(FlinkLogicalCorrelate.class, any()),
"PythonCorrelateSplitRule");
+ RemoteCorrelateSplitRule(Config config, RemoteCalcCallFinder callFinder) {
+ super(config);
+ this.callFinder = callFinder;
}
private FlinkLogicalTableFunctionScan createNewScan(
FlinkLogicalTableFunctionScan scan, ScalarFunctionSplitter
splitter) {
RexCall rightRexCall = (RexCall) scan.getCall();
- // extract Java funcs from Python TableFunction or Python funcs from
Java TableFunction.
+ // extract Java funcs from Remote TableFunction or Remote funcs from
Java TableFunction.
List<RexNode> rightCalcProjects =
rightRexCall.getOperands().stream()
.map(x -> x.accept(splitter))
@@ -96,12 +99,10 @@ public class PythonCorrelateSplitRule extends RelOptRule {
}
RexNode rexNode = tableFunctionScan.getCall();
if (rexNode instanceof RexCall) {
- return PythonUtil.isPythonCall(rexNode, null)
- && PythonUtil.containsNonPythonCall(rexNode)
- || PythonUtil.isNonPythonCall(rexNode)
- && PythonUtil.containsPythonCall(rexNode, null)
- || (PythonUtil.isPythonCall(rexNode, null)
- && RexUtil.containsFieldAccess(rexNode));
+ return (callFinder.isRemoteCall(rexNode) &&
callFinder.containsNonRemoteCall(rexNode))
+ || (callFinder.isNonRemoteCall(rexNode)
+ && callFinder.containsRemoteCall(rexNode))
+ || (callFinder.isRemoteCall(rexNode) &&
RexUtil.containsFieldAccess(rexNode));
}
return false;
}
@@ -130,6 +131,15 @@ public class PythonCorrelateSplitRule extends RelOptRule {
}
}
+ @Override
+ public RexNode visitCall(RexCall call) {
+ List<RexNode> newProjects =
+ call.getOperands().stream()
+ .map(x -> x.accept(this))
+ .collect(Collectors.toList());
+ return rexBuilder.makeCall(call.getOperator(),
newProjects);
+ }
+
@Override
public RexNode visitNode(RexNode rexNode) {
return rexNode;
@@ -223,30 +233,30 @@ public class PythonCorrelateSplitRule extends RelOptRule {
}
private ScalarFunctionSplitter createScalarFunctionSplitter(
- RexProgram program,
RexBuilder rexBuilder,
int primitiveLeftFieldCount,
ArrayBuffer<RexNode> extractedRexNodes,
RexNode tableFunctionNode) {
return new ScalarFunctionSplitter(
- program,
+ // The scan should not contain any local references to
resolve, so null is passed.
+ null,
rexBuilder,
primitiveLeftFieldCount,
extractedRexNodes,
node -> {
- if (PythonUtil.isNonPythonCall(tableFunctionNode)) {
- // splits the RexCalls which contain Python functions
into separate node
- return PythonUtil.isPythonCall(node, null);
- } else if (PythonUtil.containsNonPythonCall(node)) {
- // splits the RexCalls which contain non-Python
functions into separate node
- return PythonUtil.isNonPythonCall(node);
+ if (callFinder.isNonRemoteCall(tableFunctionNode)) {
+ // splits the RexCalls which contain Remote functions
into separate node
+ return callFinder.isRemoteCall(node);
+ } else if (callFinder.containsNonRemoteCall(node)) {
+ // splits the RexCalls which contain non-Remote
functions into separate node
+ return callFinder.isNonRemoteCall(node);
} else {
- // splits the RexFieldAccesses which contain
non-Python functions into
+ // splits the RexFieldAccesses which contain
non-Remote functions into
// separate node
return node instanceof RexFieldAccess;
}
},
- new PythonRemoteCalcCallFinder());
+ callFinder);
}
@Override
@@ -265,7 +275,6 @@ public class PythonCorrelateSplitRule extends RelOptRule {
createNewScan(
scan,
createScalarFunctionSplitter(
- null,
rexBuilder,
primitiveLeftFieldCount,
extractedRexNodes,
@@ -278,7 +287,6 @@ public class PythonCorrelateSplitRule extends RelOptRule {
createNewScan(
scan,
createScalarFunctionSplitter(
- null,
rexBuilder,
primitiveLeftFieldCount,
extractedRexNodes,
@@ -288,7 +296,7 @@ public class PythonCorrelateSplitRule extends RelOptRule {
}
FlinkLogicalCorrelate newCorrelate;
- if (extractedRexNodes.size() > 0) {
+ if (!extractedRexNodes.isEmpty()) {
FlinkLogicalCalc leftCalc =
createNewLeftCalc(left, rexBuilder, extractedRexNodes,
correlate);
@@ -323,4 +331,41 @@ public class PythonCorrelateSplitRule extends RelOptRule {
call.transformTo(newTopCalc);
}
+
+ // Consider the rules to be equal if they are the same class and their
call finders are the same
+ // class.
+ @Override
+ public boolean equals(Object object) {
+ if (object == null ||
!object.getClass().equals(RemoteCorrelateSplitRule.class)) {
+ return false;
+ }
+ RemoteCorrelateSplitRule rule = (RemoteCorrelateSplitRule) object;
+ return callFinder.equals(rule.callFinder);
+ }
+
+ @Override
+ public int hashCode() {
+ return callFinder.hashCode();
+ }
+
+ @Value.Immutable(singleton = false)
+ public interface Config extends RelRule.Config {
+
+ public abstract RemoteCalcCallFinder callFinder();
+
+ static RemoteCorrelateSplitRule.Config
createDefault(RemoteCalcCallFinder callFinder) {
+ return ImmutableRemoteCorrelateSplitRule.Config.builder()
+ .callFinder(callFinder)
+ .build()
+ .withOperandSupplier(b0 ->
b0.operand(FlinkLogicalCorrelate.class).anyInputs());
+ }
+
+ @Override
+ default RelOptRule toRule() {
+ return new RemoteCorrelateSplitRule(
+ this.withDescription("RemoteCorrelateSplitRule-" +
callFinder().getName())
+ .as(Config.class),
+ callFinder());
+ }
+ }
}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
index d029cf6e44c..74219150ed0 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/FlinkStreamRuleSets.scala
@@ -417,7 +417,9 @@ object FlinkStreamRuleSets {
// Avoid async calls which call async calls.
AsyncCalcSplitRule.NESTED_SPLIT,
// Avoid having async calls in multiple projections in a single calc.
- AsyncCalcSplitRule.ONE_PER_CALC_SPLIT
+ AsyncCalcSplitRule.ONE_PER_CALC_SPLIT,
+ // Split async calls from correlates
+ AsyncCorrelateSplitRule.INSTANCE
)
/** RuleSet to do physical optimize for stream */
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
index 1756b9eb5e8..e5cf07025c3 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/PythonCalcSplitRule.scala
@@ -81,6 +81,16 @@ class PythonRemoteCalcCallFinder extends
RemoteCalcCallFinder {
override def isNonRemoteCall(node: RexNode): Boolean = {
PythonUtil.isNonPythonCall(node)
}
+
+ override def equals(obj: Any): Boolean = {
+ obj != null && obj.isInstanceOf[PythonRemoteCalcCallFinder]
+ }
+
+ override def hashCode(): Int = {
+ this.getClass.hashCode()
+ }
+
+ override def getName: String = "Python"
}
object PythonCalcSplitRule {
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
index 5209ab1b14b..68c0e281b52 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/rules/logical/RemoteCalcCallFinder.java
@@ -34,4 +34,7 @@ public interface RemoteCalcCallFinder {
// If the node contains directly a non-remote call.
boolean isNonRemoteCall(RexNode node);
+
+ // A name that can be appended onto the rule
+ String getName();
}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java
index e97abdf1cfc..1098c0585b2 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCalcSplitRuleTest.java
@@ -22,6 +22,7 @@ import org.apache.flink.table.annotation.DataTypeHint;
import org.apache.flink.table.api.TableConfig;
import org.apache.flink.table.api.TableEnvironment;
import org.apache.flink.table.functions.AsyncScalarFunction;
+import org.apache.flink.table.functions.TableFunction;
import
org.apache.flink.table.planner.plan.optimize.program.BatchOptimizeContext;
import
org.apache.flink.table.planner.plan.optimize.program.FlinkChainedProgram;
import
org.apache.flink.table.planner.plan.optimize.program.FlinkHepRuleSetProgramBuilder;
@@ -82,6 +83,7 @@ public class AsyncCalcSplitRuleTest extends TableTestBase {
util.addTemporarySystemFunction("func4", new Func4());
util.addTemporarySystemFunction("func5", new Func5());
util.addTemporarySystemFunction("func6", new Func6());
+ util.addTemporarySystemFunction("tableFunc", new
RandomTableFunction());
}
@Test
@@ -370,4 +372,13 @@ public class AsyncCalcSplitRuleTest extends TableTestBase {
future.complete(param + param2);
}
}
+
+ /** Test function. */
+ public static class RandomTableFunction extends TableFunction<String> {
+
+ public void eval(Integer i) {
+ collect("blah " + i);
+ collect("foo " + i);
+ }
+ }
}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java
new file mode 100644
index 00000000000..77164b9185a
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.planner.plan.rules.logical;
+
+import org.apache.flink.table.api.TableConfig;
+import org.apache.flink.table.api.TableEnvironment;
+import
org.apache.flink.table.planner.plan.optimize.program.FlinkChainedProgram;
+import
org.apache.flink.table.planner.plan.optimize.program.FlinkHepRuleSetProgramBuilder;
+import
org.apache.flink.table.planner.plan.optimize.program.HEP_RULES_EXECUTION_TYPE;
+import
org.apache.flink.table.planner.plan.optimize.program.StreamOptimizeContext;
+import org.apache.flink.table.planner.plan.rules.FlinkStreamRuleSets;
+import
org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRuleTest.Func1;
+import
org.apache.flink.table.planner.plan.rules.logical.AsyncCalcSplitRuleTest.RandomTableFunction;
+import org.apache.flink.table.planner.utils.TableTestBase;
+import org.apache.flink.table.planner.utils.TableTestUtil;
+
+import org.apache.calcite.plan.hep.HepMatchOrder;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/** Test for {@link AsyncCorrelateSplitRule}. */
+public class AsyncCorrelateSplitRuleTest extends TableTestBase {
+
+ private final TableTestUtil util =
streamTestUtil(TableConfig.getDefault());
+
+ @BeforeEach
+ public void setup() {
+ FlinkChainedProgram<StreamOptimizeContext> programs = new
FlinkChainedProgram<>();
+ programs.addLast(
+ "logical_rewrite",
+
FlinkHepRuleSetProgramBuilder.<StreamOptimizeContext>newBuilder()
+
.setHepRulesExecutionType(HEP_RULES_EXECUTION_TYPE.RULE_SEQUENCE())
+ .setHepMatchOrder(HepMatchOrder.BOTTOM_UP)
+ .add(FlinkStreamRuleSets.LOGICAL_REWRITE())
+ .build());
+
+ TableEnvironment tEnv = util.getTableEnv();
+ tEnv.executeSql(
+ "CREATE TABLE MyTable (\n"
+ + " a int,\n"
+ + " b bigint,\n"
+ + " c string,\n"
+ + " d ARRAY<INT NOT NULL>\n"
+ + ") WITH (\n"
+ + " 'connector' = 'test-simple-table-source'\n"
+ + ") ;");
+
+ util.addTemporarySystemFunction("func1", new Func1());
+ util.addTemporarySystemFunction("tableFunc", new
RandomTableFunction());
+ }
+
+ @Test
+ public void testCorrelateImmediate() {
+ String sqlQuery = "select * FROM MyTable, LATERAL
TABLE(tableFunc(func1(a)))";
+ util.verifyRelPlan(sqlQuery);
+ }
+
+ @Test
+ public void testCorrelateIndirect() {
+ String sqlQuery = "select * FROM MyTable, LATERAL
TABLE(tableFunc(ABS(func1(a))))";
+ util.verifyRelPlan(sqlQuery);
+ }
+
+ @Test
+ public void testCorrelateIndirectOtherWay() {
+ String sqlQuery = "select * FROM MyTable, LATERAL
TABLE(tableFunc(func1(ABS(a))))";
+ util.verifyRelPlan(sqlQuery);
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
index ec0eefe33ae..923735c0eb2 100644
---
a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
+++
b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/table/AsyncCalcITCase.java
@@ -28,6 +28,7 @@ import
org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.api.config.ExecutionConfigOptions;
import org.apache.flink.table.functions.AsyncScalarFunction;
import org.apache.flink.table.functions.FunctionContext;
+import org.apache.flink.table.functions.TableFunction;
import org.apache.flink.table.planner.runtime.utils.StreamingTestBase;
import org.apache.flink.types.Row;
import org.apache.flink.types.RowKind;
@@ -240,6 +241,22 @@ public class AsyncCalcITCase extends StreamingTestBase {
assertThat(results).containsSequence(expectedRows);
}
+ @Test
+ public void testTableFuncWithAsyncCalc() { // table function whose argument is a nested addTen(f1) call
+ Table t1 = tEnv.fromValues(1, 2).as("f1");
+ tEnv.createTemporaryView("t1", t1);
+ tEnv.createTemporarySystemFunction("func", new RandomTableFunction());
+ tEnv.createTemporarySystemFunction("addTen", new AsyncFuncAdd10());
+ final List<Row> results = executeSql("select * FROM t1, LATERAL
TABLE(func(addTen(f1)))");
+ final List<Row> expectedRows = // per input f1: one "blah" row and one "foo" row, each carrying f1 + 10
+ Arrays.asList(
+ Row.of(1, "blah 11"),
+ Row.of(1, "foo 11"),
+ Row.of(2, "blah 12"),
+ Row.of(2, "foo 12"));
+ assertThat(results).containsSequence(expectedRows);
+ }
+
@Test
public void testMultiArgumentAsyncWithAdditionalProjection() {
// This was the cause of a bug previously where the reference to the
sync projection was
@@ -423,4 +440,13 @@ public class AsyncCalcITCase extends StreamingTestBase {
executor.schedule(() -> future.complete(param1 + param2), 10,
TimeUnit.MILLISECONDS);
}
}
+
+ /**
+ * A table function producing {@code String} rows. Each {@code eval} call collects two
+ * values built from its argument: {@code "blah " + i} and {@code "foo " + i}.
+ */
+ public static class RandomTableFunction extends TableFunction<String> {
+ /** Collects {@code "blah " + i} followed by {@code "foo " + i} for the given input. */
+ public void eval(Integer i) {
+ collect("blah " + i);
+ collect("foo " + i);
+ }
+ }
}
diff --git
a/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml
new file mode 100644
index 00000000000..00b283f7b1c
--- /dev/null
+++
b/flink-table/flink-table-planner/src/test/resources/org/apache/flink/table/planner/plan/rules/logical/AsyncCorrelateSplitRuleTest.xml
@@ -0,0 +1,84 @@
+<?xml version="1.0" ?>
+<!--
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements. See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License. You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+-->
+<Root>
+ <TestCase name="testCorrelateImmediate">
+ <Resource name="sql">
+ <![CDATA[select * FROM MyTable, LATERAL
TABLE(tableFunc(func1(a)))]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], EXPR$0=[$4])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{0}])
+ :- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+ +- LogicalTableFunctionScan(invocation=[tableFunc(func1($cor0.a))],
rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[a, b, c, d, EXPR$0])
++- Correlate(invocation=[tableFunc($4)], correlate=[table(tableFunc(f0))],
select=[a,b,c,d,f0,EXPR$0], rowType=[RecordType(INTEGER a, BIGINT b,
VARCHAR(2147483647) c, INTEGER ARRAY d, INTEGER f0, VARCHAR(2147483647)
EXPR$0)], joinType=[INNER])
+ +- AsyncCalc(select=[a, b, c, d, func1(a) AS f0])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable]],
fields=[a, b, c, d])
+]]>
+ </Resource>
+ </TestCase>
+
+ <TestCase name="testCorrelateIndirect">
+ <Resource name="sql">
+ <![CDATA[select * FROM MyTable, LATERAL
TABLE(tableFunc(ABS(func1(a))))]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], EXPR$0=[$4])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{0}])
+ :- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+ +- LogicalTableFunctionScan(invocation=[tableFunc(ABS(func1($cor0.a)))],
rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[a, b, c, d, EXPR$0])
++- Correlate(invocation=[tableFunc(ABS($4))],
correlate=[table(tableFunc(ABS(f0)))], select=[a,b,c,d,f0,EXPR$0],
rowType=[RecordType(INTEGER a, BIGINT b, VARCHAR(2147483647) c, INTEGER ARRAY
d, INTEGER f0, VARCHAR(2147483647) EXPR$0)], joinType=[INNER])
+ +- AsyncCalc(select=[a, b, c, d, func1(a) AS f0])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable]],
fields=[a, b, c, d])
+]]>
+ </Resource>
+ </TestCase>
+
+ <TestCase name="testCorrelateIndirectOtherWay">
+ <Resource name="sql">
+ <![CDATA[select * FROM MyTable, LATERAL
TABLE(tableFunc(func1(ABS(a))))]]>
+ </Resource>
+ <Resource name="ast">
+ <![CDATA[
+LogicalProject(a=[$0], b=[$1], c=[$2], d=[$3], EXPR$0=[$4])
++- LogicalCorrelate(correlation=[$cor0], joinType=[inner],
requiredColumns=[{0}])
+ :- LogicalTableScan(table=[[default_catalog, default_database, MyTable]])
+ +- LogicalTableFunctionScan(invocation=[tableFunc(func1(ABS($cor0.a)))],
rowType=[RecordType(VARCHAR(2147483647) EXPR$0)])
+]]>
+ </Resource>
+ <Resource name="optimized rel plan">
+ <![CDATA[
+Calc(select=[a, b, c, d, EXPR$0])
++- Correlate(invocation=[tableFunc($4)], correlate=[table(tableFunc(f0))],
select=[a,b,c,d,f0,EXPR$0], rowType=[RecordType(INTEGER a, BIGINT b,
VARCHAR(2147483647) c, INTEGER ARRAY d, INTEGER f0, VARCHAR(2147483647)
EXPR$0)], joinType=[INNER])
+ +- AsyncCalc(select=[a, b, c, d, func1(ABS(a)) AS f0])
+ +- TableSourceScan(table=[[default_catalog, default_database, MyTable]],
fields=[a, b, c, d])
+]]>
+ </Resource>
+ </TestCase>
+</Root>