This is an automated email from the ASF dual-hosted git repository.
shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 8b6e69e6f7e [FLINK-38435][table] Refactor codegen and runner for
MLPredict and LookupJoin (#27041)
8b6e69e6f7e is described below
commit 8b6e69e6f7ecb105dd7478a3613c2a656de4e9e9
Author: Shengkai <[email protected]>
AuthorDate: Thu Oct 16 09:40:08 2025 +0800
[FLINK-38435][table] Refactor codegen and runner for MLPredict and
LookupJoin (#27041)
---
...upCallContext.java => FunctionCallContext.java} | 24 +-
.../nodes/exec/common/CommonExecLookupJoin.java | 3 +-
.../nodes/exec/stream/StreamExecDeltaJoin.java | 3 +-
.../stream/StreamExecMLPredictTableFunction.java | 69 +---
.../codegen/FunctionCallCodeGenerator.scala | 344 ++++++++++++++++++
.../planner/codegen/LookupJoinCodeGenerator.scala | 403 +++++++--------------
.../planner/codegen/MLPredictCodeGenerator.scala | 144 ++++++++
...unner.java => AbstractAsyncFunctionRunner.java} | 34 +-
...tionRunner.java => AbstractFunctionRunner.java} | 40 +-
.../operators/calc/async/AsyncFunctionRunner.java | 26 +-
.../join/lookup/AsyncLookupJoinRunner.java | 14 +-
.../operators/join/lookup/LookupJoinRunner.java | 14 +-
.../runtime/operators/ml/AsyncMLPredictRunner.java | 138 +++++++
.../runtime/operators/ml/MLPredictRunner.java | 73 ++++
14 files changed, 896 insertions(+), 433 deletions(-)
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java
similarity index 87%
rename from
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
rename to
flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java
index ce655466c0f..562012f225b 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java
@@ -21,6 +21,7 @@ import org.apache.flink.annotation.Internal;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.connector.source.LookupTableSource;
import org.apache.flink.table.functions.UserDefinedFunction;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.Constant;
import
org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam;
import org.apache.flink.table.types.DataType;
@@ -38,24 +39,27 @@ import static
org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FieldRe
import static
org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes;
import static
org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType;
-/** The {@link CallContext} of a {@link LookupTableSource} runtime function. */
+/**
+ * The {@link CallContext} of a {@link LookupTableSource} or {@link PredictRuntimeProvider}
+ * runtime function.
+ */
@Internal
-public class LookupCallContext extends AbstractSqlCallContext {
+public class FunctionCallContext extends AbstractSqlCallContext {
- private final List<FunctionParam> lookupKeys;
+ private final List<FunctionParam> params;
private final List<DataType> argumentDataTypes;
private final DataType outputDataType;
- public LookupCallContext(
+ public FunctionCallContext(
DataTypeFactory dataTypeFactory,
UserDefinedFunction function,
LogicalType inputType,
- List<FunctionParam> lookupKeys,
- LogicalType lookupType) {
+ List<FunctionParam> params,
+ LogicalType outputDataType) {
super(dataTypeFactory, function, generateInlineFunctionName(function),
false);
- this.lookupKeys = lookupKeys;
+ this.params = params;
this.argumentDataTypes =
new AbstractList<>() {
@Override
@@ -74,10 +78,10 @@ public class LookupCallContext extends
AbstractSqlCallContext {
@Override
public int size() {
- return lookupKeys.size();
+ return params.size();
}
};
- this.outputDataType = fromLogicalToDataType(lookupType);
+ this.outputDataType = fromLogicalToDataType(outputDataType);
}
@Override
@@ -118,6 +122,6 @@ public class LookupCallContext extends
AbstractSqlCallContext {
//
--------------------------------------------------------------------------------------------
private FunctionParam getKey(int pos) {
- return lookupKeys.get(pos);
+ return params.get(pos);
}
}
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
index bbf4486c5ed..05b07458a10 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
@@ -46,6 +46,7 @@ import org.apache.flink.table.legacy.sources.TableSource;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
+import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -459,7 +460,7 @@ public abstract class CommonExecLookupJoin extends
ExecNodeBase<RowData> {
.mapToObj(allLookupKeys::get)
.collect(Collectors.toList());
-
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
generatedFuncWithType =
LookupJoinCodeGenerator.generateAsyncLookupFunction(
config,
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
index 39a48443e0d..6d368618c91 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
@@ -31,6 +31,7 @@ import
org.apache.flink.table.data.conversion.DataStructureConverters;
import org.apache.flink.table.functions.AsyncTableFunction;
import org.apache.flink.table.functions.UserDefinedFunctionHelper;
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -353,7 +354,7 @@ public class StreamExecDeltaJoin extends
ExecNodeBase<RowData>
.mapToObj(lookupKeys::get)
.collect(Collectors.toList());
-
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
lookupSideGeneratedFuncWithType =
LookupJoinCodeGenerator.generateAsyncLookupFunction(
config,
diff --git
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
index 0ff85e23031..6d1fe8a8cee 100644
---
a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
+++
b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
@@ -31,8 +31,6 @@ import
org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.catalog.DataTypeFactory;
import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.conversion.DataStructureConverter;
-import org.apache.flink.table.data.conversion.DataStructureConverters;
import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.functions.PredictFunction;
import org.apache.flink.table.functions.UserDefinedFunction;
@@ -41,8 +39,8 @@ import org.apache.flink.table.ml.ModelProvider;
import org.apache.flink.table.ml.PredictRuntimeProvider;
import org.apache.flink.table.planner.calcite.FlinkContext;
import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
-import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
-import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
+import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
+import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator;
import org.apache.flink.table.planner.delegation.PlannerBase;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
@@ -55,18 +53,15 @@ import
org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
-import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
import org.apache.flink.table.runtime.collector.ListenableCollector;
-import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
import
org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext;
import org.apache.flink.table.runtime.generated.GeneratedCollector;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
-import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
-import
org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner;
-import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner;
-import org.apache.flink.table.runtime.typeutils.InternalSerializers;
+import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner;
+import org.apache.flink.table.runtime.operators.ml.MLPredictRunner;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
import
org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
@@ -75,7 +70,6 @@ import javax.annotation.Nullable;
import java.util.Collections;
import java.util.List;
-import java.util.Optional;
/** Stream {@link ExecNode} for {@code ML_PREDICT}. */
@ExecNodeMetadata(
@@ -197,7 +191,7 @@ public class StreamExecMLPredictTableFunction extends
ExecNodeBase<RowData>
RowType resultRowType,
PredictFunction predictFunction) {
GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
- LookupJoinCodeGenerator.generateSyncLookupFunction(
+ MLPredictCodeGenerator.generateSyncPredictFunction(
config,
classLoader,
dataTypeFactory,
@@ -206,25 +200,15 @@ public class StreamExecMLPredictTableFunction extends
ExecNodeBase<RowData>
resultRowType,
mlPredictSpec.getFeatures(),
predictFunction,
- "MLPredict",
+
modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(),
config.get(PipelineOptions.OBJECT_REUSE));
GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
- LookupJoinCodeGenerator.generateCollector(
+ MLPredictCodeGenerator.generateCollector(
new CodeGeneratorContext(config, classLoader),
inputRowType,
modelOutputType,
- (RowType) getOutputType(),
- JavaScalaConversionUtil.toScala(Optional.empty()),
- JavaScalaConversionUtil.toScala(Optional.empty()),
- true);
- LookupJoinRunner mlPredictRunner =
- new LookupJoinRunner(
- generatedFetcher,
- generatedCollector,
- FilterCodeGenerator.generateFilterCondition(
- config, classLoader, null, inputRowType),
- false,
- modelOutputType.getFieldCount());
+ (RowType) getOutputType());
+ MLPredictRunner mlPredictRunner = new
MLPredictRunner(generatedFetcher, generatedCollector);
SimpleOperatorFactory<RowData> operatorFactory =
SimpleOperatorFactory.of(new
ProcessOperator<>(mlPredictRunner));
return ExecNodeUtil.createOneInputTransformation(
@@ -246,9 +230,9 @@ public class StreamExecMLPredictTableFunction extends
ExecNodeBase<RowData>
RowType modelOutputType,
RowType resultRowType,
AsyncPredictFunction asyncPredictFunction) {
-
LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
+
FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData,
Object>>
generatedFuncWithType =
- LookupJoinCodeGenerator.generateAsyncLookupFunction(
+ MLPredictCodeGenerator.generateAsyncPredictFunction(
config,
classLoader,
dataTypeFactory,
@@ -257,29 +241,14 @@ public class StreamExecMLPredictTableFunction extends
ExecNodeBase<RowData>
resultRowType,
mlPredictSpec.getFeatures(),
asyncPredictFunction,
- "AsyncMLPredict");
-
- GeneratedResultFuture<TableFunctionResultFuture<RowData>>
generatedResultFuture =
- LookupJoinCodeGenerator.generateTableAsyncCollector(
- config,
- classLoader,
- "TableFunctionResultFuture",
- inputRowType,
- modelOutputType,
- JavaScalaConversionUtil.toScala(Optional.empty()));
-
- DataStructureConverter<?, ?> fetcherConverter =
-
DataStructureConverters.getConverter(generatedFuncWithType.dataType());
+ modelSpec
+ .getContextResolvedModel()
+ .getIdentifier()
+ .asSummaryString());
AsyncFunction<RowData, RowData> asyncFunc =
- new AsyncLookupJoinRunner(
- generatedFuncWithType.tableFunc(),
- (DataStructureConverter<RowData, Object>)
fetcherConverter,
- generatedResultFuture,
- FilterCodeGenerator.generateFilterCondition(
- config, classLoader, null, inputRowType),
- InternalSerializers.create(modelOutputType),
- false,
- asyncOptions.asyncBufferCapacity);
+ new AsyncMLPredictRunner(
+ (GeneratedFunction) generatedFuncWithType.tableFunc(),
+
Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity);
return ExecNodeUtil.createOneInputTransformation(
inputTransformation,
createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala
new file mode 100644
index 00000000000..438f14c21e3
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.codegen
+
+import org.apache.flink.api.common.functions.{FlatMapFunction, Function,
OpenContext}
+import org.apache.flink.configuration.ReadableConfig
+import org.apache.flink.streaming.api.functions.async.AsyncFunction
+import org.apache.flink.table.catalog.DataTypeFactory
+import org.apache.flink.table.data.{GenericRowData, RowData}
+import org.apache.flink.table.data.utils.JoinedRowData
+import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction,
UserDefinedFunction, UserDefinedFunctionHelper}
+import
org.apache.flink.table.planner.codegen.CodeGenUtils.{boxedTypeTermForType,
className, newName, DEFAULT_COLLECTOR_TERM, DEFAULT_INPUT1_TERM,
DEFAULT_INPUT2_TERM}
+import
org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputAccess,
generateLiteral}
+import org.apache.flink.table.planner.delegation.PlannerBase
+import org.apache.flink.table.planner.functions.inference.FunctionCallContext
+import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.{Constant,
FieldRef, FunctionParam}
+import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
+import org.apache.flink.table.runtime.collector.ListenableCollector
+import
org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
+import org.apache.flink.table.runtime.generated.{GeneratedCollector,
GeneratedFunction}
+import org.apache.flink.table.types.DataType
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.util.Collector
+
+import org.apache.calcite.rex.RexNode
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+object FunctionCallCodeGenerator {
+
+ case class GeneratedTableFunctionWithDataType[F <: Function](
+ tableFunc: GeneratedFunction[F],
+ dataType: DataType)
+
+ /** Generates a sync function ([[TableFunction]]) call. */
+ def generateSyncFunctionCall(
+ tableConfig: ReadableConfig,
+ classLoader: ClassLoader,
+ dataTypeFactory: DataTypeFactory,
+ inputType: LogicalType,
+ functionOutputType: LogicalType,
+ collectorOutputType: LogicalType,
+ parameters: util.List[FunctionParam],
+ syncFunctionDefinition: TableFunction[_],
+ inferCall: (
+ CodeGeneratorContext,
+ FunctionCallContext,
+ UserDefinedFunction,
+ Seq[GeneratedExpression]) => (GeneratedExpression, DataType),
+ functionName: String,
+ generateClassName: String,
+ fieldCopy: Boolean):
GeneratedTableFunctionWithDataType[FlatMapFunction[RowData, RowData]] = {
+
+ val bodyCode: GeneratedExpression => String = call => {
+ val resultCollectorTerm = call.resultTerm
+ s"""
+ |$resultCollectorTerm.setCollector($DEFAULT_COLLECTOR_TERM);
+ |${call.code}
+ |""".stripMargin
+ }
+
+ generateFunctionCall(
+ classOf[FlatMapFunction[RowData, RowData]],
+ tableConfig,
+ classLoader,
+ dataTypeFactory,
+ inputType,
+ functionOutputType,
+ collectorOutputType,
+ parameters,
+ syncFunctionDefinition,
+ inferCall,
+ functionName,
+ generateClassName,
+ fieldCopy,
+ bodyCode
+ )
+ }
+
+ /** Generates an async function ([[AsyncTableFunction]]) call. */
+ def generateAsyncFunctionCall(
+ tableConfig: ReadableConfig,
+ classLoader: ClassLoader,
+ dataTypeFactory: DataTypeFactory,
+ inputType: LogicalType,
+ functionOutputType: LogicalType,
+ collectorOutputType: LogicalType,
+ parameters: util.List[FunctionParam],
+ asyncFunctionDefinition: AsyncTableFunction[_],
+ generateCallWithDataType: (
+ CodeGeneratorContext,
+ FunctionCallContext,
+ UserDefinedFunction,
+ Seq[GeneratedExpression]) => (GeneratedExpression, DataType),
+ functionName: String,
+ generateClassName: String
+ ): GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
+ generateFunctionCall(
+ classOf[AsyncFunction[RowData, AnyRef]],
+ tableConfig,
+ classLoader,
+ dataTypeFactory,
+ inputType,
+ functionOutputType,
+ collectorOutputType,
+ parameters,
+ asyncFunctionDefinition,
+ generateCallWithDataType,
+ functionName,
+ generateClassName,
+ fieldCopy = true,
+ _.code
+ )
+ }
+
+ private def generateFunctionCall[F <: Function](
+ generatedClass: Class[F],
+ tableConfig: ReadableConfig,
+ classLoader: ClassLoader,
+ dataTypeFactory: DataTypeFactory,
+ inputType: LogicalType,
+ functionOutputType: LogicalType,
+ collectorOutputType: LogicalType,
+ parameters: util.List[FunctionParam],
+ functionDefinition: UserDefinedFunction,
+ generateCallWithDataType: (
+ CodeGeneratorContext,
+ FunctionCallContext,
+ UserDefinedFunction,
+ Seq[GeneratedExpression]) => (GeneratedExpression, DataType),
+ functionName: String,
+ generateClassName: String,
+ fieldCopy: Boolean,
+ bodyCode: GeneratedExpression => String):
GeneratedTableFunctionWithDataType[F] = {
+
+ val callContext =
+ new FunctionCallContext(
+ dataTypeFactory,
+ functionDefinition,
+ inputType,
+ parameters,
+ functionOutputType)
+
+ // create the final UDF for runtime
+ val udf = UserDefinedFunctionHelper.createSpecializedFunction(
+ functionName,
+ functionDefinition,
+ callContext,
+ classOf[PlannerBase].getClassLoader,
+ tableConfig,
+ // no need to support expression evaluation at this point
+ null
+ )
+
+ val ctx = new CodeGeneratorContext(tableConfig, classLoader)
+ val operands = prepareOperands(ctx, inputType, parameters, fieldCopy)
+
+ val callWithDataType: (GeneratedExpression, DataType) =
+ generateCallWithDataType(ctx, callContext, udf, operands)
+
+ val function = FunctionCodeGenerator.generateFunction(
+ ctx,
+ generateClassName,
+ generatedClass,
+ bodyCode(callWithDataType._1),
+ collectorOutputType,
+ inputType)
+
+ GeneratedTableFunctionWithDataType(function, callWithDataType._2)
+ }
+
+ private def prepareOperands(
+ ctx: CodeGeneratorContext,
+ inputType: LogicalType,
+ parameters: util.List[FunctionParam],
+ fieldCopy: Boolean): Seq[GeneratedExpression] = {
+
+ parameters.asScala
+ .map {
+ case constantKey: Constant =>
+ val res = RexLiteralUtil.toFlinkInternalValue(constantKey.literal)
+ generateLiteral(ctx, res.f0, res.f1)
+ case fieldKey: FieldRef =>
+ generateInputAccess(
+ ctx,
+ inputType,
+ DEFAULT_INPUT1_TERM,
+ fieldKey.index,
+ nullableInput = false,
+ fieldCopy)
+ case _ =>
+ throw new CodeGenException("Invalid parameters.")
+ }
+ }
+
+ /**
+ * Generates collector for join ([[Collector]])
+ *
+ * Differs from CommonCorrelate.generateCollector which has no real
condition because of
+ * FLINK-7865, here we should deal with outer join type when real conditions
filtered result.
+ */
+ def generateCollector(
+ ctx: CodeGeneratorContext,
+ inputRowType: RowType,
+ rightRowType: RowType,
+ resultRowType: RowType,
+ condition: Option[RexNode],
+ pojoFieldMapping: Option[Array[Int]],
+ retainHeader: Boolean = true):
GeneratedCollector[ListenableCollector[RowData]] = {
+
+ val inputTerm = DEFAULT_INPUT1_TERM
+ val rightInputTerm = DEFAULT_INPUT2_TERM
+
+ val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
+ .bindInput(rightRowType, inputTerm = rightInputTerm, inputFieldMapping =
pojoFieldMapping)
+
+ val rightResultExpr =
+ exprGenerator.generateConverterResultExpression(rightRowType,
classOf[GenericRowData])
+
+ val joinedRowTerm = CodeGenUtils.newName(ctx, "joinedRow")
+ ctx.addReusableOutputRecord(resultRowType, classOf[JoinedRowData],
joinedRowTerm)
+
+ val header = if (retainHeader) {
+ s"$joinedRowTerm.setRowKind($inputTerm.getRowKind());"
+ } else {
+ ""
+ }
+
+ val body =
+ s"""
+ |${rightResultExpr.code}
+ |$joinedRowTerm.replace($inputTerm, ${rightResultExpr.resultTerm});
+ |$header
+ |outputResult($joinedRowTerm);
+ """.stripMargin
+
+ val collectorCode = if (condition.isEmpty) {
+ body
+ } else {
+
+ val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
+ .bindInput(inputRowType, inputTerm)
+ .bindSecondInput(rightRowType, rightInputTerm, pojoFieldMapping)
+ val filterCondition = filterGenerator.generateExpression(condition.get)
+
+ s"""
+ |${filterCondition.code}
+ |if (${filterCondition.resultTerm}) {
+ | $body
+ |}
+ |""".stripMargin
+ }
+
+ generateTableFunctionCollectorForJoinTable(
+ ctx,
+ "JoinTableFuncCollector",
+ collectorCode,
+ inputRowType,
+ rightRowType,
+ inputTerm = inputTerm,
+ collectedTerm = rightInputTerm)
+ }
+
+ /**
+ * The only difference from CollectorCodeGenerator.generateTableFunctionCollector is that the
+ * "super.collect" call is bound to collecting the joined row in the "body" code.
+ */
+ private def generateTableFunctionCollectorForJoinTable(
+ ctx: CodeGeneratorContext,
+ name: String,
+ bodyCode: String,
+ inputType: RowType,
+ collectedType: RowType,
+ inputTerm: String = DEFAULT_INPUT1_TERM,
+ collectedTerm: String = DEFAULT_INPUT2_TERM)
+ : GeneratedCollector[ListenableCollector[RowData]] = {
+
+ val funcName = newName(ctx, name)
+ val input1TypeClass = boxedTypeTermForType(inputType)
+ val input2TypeClass = boxedTypeTermForType(collectedType)
+
+ val funcCode =
+ s"""
+ public class $funcName extends
${classOf[ListenableCollector[_]].getCanonicalName} {
+
+ ${ctx.reuseMemberCode()}
+
+ public $funcName(Object[] references) throws Exception {
+ ${ctx.reuseInitCode()}
+ }
+
+ @Override
+ public void open(${className[OpenContext]} openContext) throws
Exception {
+ ${ctx.reuseOpenCode()}
+ }
+
+ @Override
+ public void collect(Object record) throws Exception {
+ $input1TypeClass $inputTerm = ($input1TypeClass) getInput();
+ $input2TypeClass $collectedTerm = ($input2TypeClass) record;
+
+ // callback only when collectListener exists, equivalent to:
+ // getCollectListener().ifPresent(
+ // listener -> ((CollectListener) listener).onCollect(record));
+ // TODO we should update code splitter's grammar file to accept
lambda expressions.
+
+ if (getCollectListener().isPresent()) {
+ ((${classOf[CollectListener[_]].getCanonicalName})
getCollectListener().get())
+ .onCollect(record);
+ }
+
+ ${ctx.reuseLocalVariableCode()}
+ ${ctx.reuseInputUnboxingCode()}
+ ${ctx.reusePerRecordCode()}
+ $bodyCode
+ }
+
+ @Override
+ public void close() throws Exception {
+ ${ctx.reuseCloseCode()}
+ }
+ }
+ """.stripMargin
+
+ new GeneratedCollector(funcName, funcCode, ctx.references.toArray,
ctx.tableConfig)
+ }
+}
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
index e61783b6844..2f3ab1708b8 100644
---
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
@@ -17,7 +17,7 @@
*/
package org.apache.flink.table.planner.codegen
-import org.apache.flink.api.common.functions.{FlatMapFunction, Function,
OpenContext}
+import org.apache.flink.api.common.functions.{FlatMapFunction, OpenContext}
import org.apache.flink.configuration.ReadableConfig
import org.apache.flink.streaming.api.functions.async.AsyncFunction
import org.apache.flink.table.api.ValidationException
@@ -25,17 +25,15 @@ import org.apache.flink.table.catalog.DataTypeFactory
import org.apache.flink.table.connector.source.{LookupTableSource,
ScanTableSource}
import org.apache.flink.table.data.{GenericRowData, RowData}
import org.apache.flink.table.data.utils.JoinedRowData
-import org.apache.flink.table.functions.{AsyncLookupFunction,
AsyncPredictFunction, AsyncTableFunction, LookupFunction, PredictFunction,
TableFunction, UserDefinedFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.{AsyncLookupFunction,
AsyncTableFunction, LookupFunction, TableFunction, UserDefinedFunction}
import org.apache.flink.table.planner.calcite.FlinkTypeFactory
import org.apache.flink.table.planner.codegen.CodeGenUtils._
-import org.apache.flink.table.planner.codegen.GenerateUtils._
+import
org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType
import org.apache.flink.table.planner.codegen.Indenter.toISC
import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil
import
org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.verifyFunctionAwareImplementation
-import org.apache.flink.table.planner.delegation.PlannerBase
-import org.apache.flink.table.planner.functions.inference.LookupCallContext
-import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.{Constant,
FieldRef, FunctionParam}
-import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
+import org.apache.flink.table.planner.functions.inference.FunctionCallContext
+import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
import org.apache.flink.table.runtime.collector.{ListenableCollector,
TableFunctionResultFuture}
import
org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
@@ -57,10 +55,6 @@ import scala.collection.JavaConverters._
object LookupJoinCodeGenerator {
- case class GeneratedTableFunctionWithDataType[F <: Function](
- tableFunc: GeneratedFunction[F],
- dataType: DataType)
-
private val ARRAY_LIST = className[util.ArrayList[_]]
/** Generates a lookup function ([[TableFunction]]) */
@@ -76,29 +70,26 @@ object LookupJoinCodeGenerator {
functionName: String,
fieldCopy: Boolean): GeneratedFunction[FlatMapFunction[RowData,
RowData]] = {
- val bodyCode: GeneratedExpression => String = call => {
- val resultCollectorTerm = call.resultTerm
- s"""
- |$resultCollectorTerm.setCollector($DEFAULT_COLLECTOR_TERM);
- |${call.code}
- |""".stripMargin
- }
-
- generateLookupFunction(
- classOf[FlatMapFunction[RowData, RowData]],
- tableConfig,
- classLoader,
- dataTypeFactory,
- inputType,
- tableSourceType,
- returnType,
- lookupKeys,
- classOf[TableFunction[_]],
- syncLookupFunction,
- functionName,
- fieldCopy,
- bodyCode
- ).tableFunc
+ FunctionCallCodeGenerator
+ .generateSyncFunctionCall(
+ tableConfig,
+ classLoader,
+ dataTypeFactory,
+ inputType,
+ tableSourceType,
+ returnType,
+ lookupKeys,
+ syncLookupFunction,
+ generateCallWithDataType(
+ dataTypeFactory,
+ functionName,
+ tableSourceType,
+ classOf[TableFunction[_]]),
+ functionName,
+ "LookupFunction",
+ fieldCopy
+ )
+ .tableFunc
}
/** Generates a async lookup function ([[AsyncTableFunction]]) */
@@ -112,9 +103,7 @@ object LookupJoinCodeGenerator {
lookupKeys: util.List[FunctionParam],
asyncLookupFunction: AsyncTableFunction[_],
functionName: String):
GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
-
- generateLookupFunction(
- classOf[AsyncFunction[RowData, AnyRef]],
+ FunctionCallCodeGenerator.generateAsyncFunctionCall(
tableConfig,
classLoader,
dataTypeFactory,
@@ -122,98 +111,66 @@ object LookupJoinCodeGenerator {
tableSourceType,
returnType,
lookupKeys,
- classOf[AsyncTableFunction[_]],
asyncLookupFunction,
+ generateCallWithDataType(
+ dataTypeFactory,
+ functionName,
+ tableSourceType,
+ classOf[AsyncTableFunction[_]]),
functionName,
- fieldCopy = true, // always copy input field because of async buffer
- _.code
+ "AsyncLookupFunction"
)
}
- private def generateLookupFunction[F <: Function](
- generatedClass: Class[F],
- tableConfig: ReadableConfig,
- classLoader: ClassLoader,
+ private def generateCallWithDataType(
dataTypeFactory: DataTypeFactory,
- inputType: LogicalType,
- tableSourceType: LogicalType,
- returnType: LogicalType,
- lookupKeys: util.List[FunctionParam],
- lookupFunctionBase: Class[_],
- lookupFunction: UserDefinedFunction,
functionName: String,
- fieldCopy: Boolean,
- bodyCode: GeneratedExpression => String):
GeneratedTableFunctionWithDataType[F] = {
-
- val callContext =
- new LookupCallContext(dataTypeFactory, lookupFunction, inputType,
lookupKeys, tableSourceType)
-
- // create the final UDF for runtime
- val udf = UserDefinedFunctionHelper.createSpecializedFunction(
- functionName,
- lookupFunction,
- callContext,
- classOf[PlannerBase].getClassLoader,
- tableConfig,
- // no need to support expression evaluation at this point
- null)
-
- val inference =
- createLookupTypeInference(dataTypeFactory, callContext,
lookupFunctionBase, udf, functionName)
-
- val ctx = new CodeGeneratorContext(tableConfig, classLoader)
- val operands = prepareOperands(ctx, inputType, lookupKeys, fieldCopy)
-
- // TODO: filter all records when there are any nulls on the join key,
because
- // "IS NOT DISTINCT FROM" is not supported yet.
- // Note: AsyncPredictFunction or PredictFunction does not use Lookup
Syntax.
- val skipIfArgsNull = !lookupFunction.isInstanceOf[PredictFunction] &&
!lookupFunction
- .isInstanceOf[AsyncPredictFunction]
-
- val callWithDataType =
BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType(
- ctx,
- operands,
- tableSourceType,
- inference,
- callContext,
- udf,
- functionName,
- skipIfArgsNull = skipIfArgsNull
- )
-
- val function = FunctionCodeGenerator.generateFunction(
- ctx,
- "LookupFunction",
- generatedClass,
- bodyCode(callWithDataType._1),
- returnType,
- inputType)
-
- GeneratedTableFunctionWithDataType(function, callWithDataType._2)
- }
-
- private def prepareOperands(
+ tableSourceType: LogicalType,
+ baseClass: Class[_]
+ ) = (
ctx: CodeGeneratorContext,
- inputType: LogicalType,
- lookupKeys: util.List[FunctionParam],
- fieldCopy: Boolean): Seq[GeneratedExpression] = {
+ callContext: FunctionCallContext,
+ udf: UserDefinedFunction,
+ operands: Seq[GeneratedExpression]) => {
+ def inferCallWithDataType(
+ ctx: CodeGeneratorContext,
+ callContext: FunctionCallContext,
+ udf: UserDefinedFunction,
+ operands: Seq[GeneratedExpression],
+ legacy: Boolean,
+ e: Exception = null): (GeneratedExpression, DataType) = {
+ val inference = createLookupTypeInference(
+ dataTypeFactory,
+ callContext,
+ baseClass,
+ udf,
+ functionName,
+ legacy,
+ e)
+
+ // TODO: filter all records when there are any nulls on the join key,
because
+ // "IS NOT DISTINCT FROM" is not supported yet.
+ val callWithDataType =
BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType(
+ ctx,
+ operands,
+ tableSourceType,
+ inference,
+ callContext,
+ udf,
+ functionName,
+ skipIfArgsNull = true
+ )
+ callWithDataType
+ }
- lookupKeys.asScala
- .map {
- case constantKey: Constant =>
- val res = RexLiteralUtil.toFlinkInternalValue(constantKey.literal)
- generateLiteral(ctx, res.f0, res.f1)
- case fieldKey: FieldRef =>
- generateInputAccess(
- ctx,
- inputType,
- DEFAULT_INPUT1_TERM,
- fieldKey.index,
- nullableInput = false,
- fieldCopy)
- case _ =>
- throw new CodeGenException("Invalid lookup key.")
- }
+ try {
+ // user provided type inference has precedence
+ // this ensures that all functions work in the same way
+ inferCallWithDataType(ctx, callContext, udf, operands, legacy = false)
+ } catch {
+ case e: Exception =>
+ inferCallWithDataType(ctx, callContext, udf, operands, legacy = true,
e)
+ }
}
/**
@@ -225,66 +182,58 @@ object LookupJoinCodeGenerator {
*/
private def createLookupTypeInference(
dataTypeFactory: DataTypeFactory,
- callContext: LookupCallContext,
+ callContext: FunctionCallContext,
baseClass: Class[_],
udf: UserDefinedFunction,
- functionName: String): TypeInference = {
+ functionName: String,
+ legacy: Boolean,
+ e: Exception): TypeInference = {
- try {
+ if (!legacy) {
// user provided type inference has precedence
// this ensures that all functions work in the same way
udf.getTypeInference(dataTypeFactory)
- } catch {
- case e: Exception =>
- // for convenience, we assume internal or default external data
structures
- // of expected logical types
- val defaultArgDataTypes = callContext.getArgumentDataTypes.asScala
- val defaultOutputDataType = callContext.getOutputDataType.get()
-
- val outputClass =
- if (
- udf.isInstanceOf[LookupFunction] ||
udf.isInstanceOf[AsyncLookupFunction] || udf
- .isInstanceOf[PredictFunction] ||
udf.isInstanceOf[AsyncPredictFunction]
- ) {
- Some(classOf[RowData])
- } else {
- toScala(extractSimpleGeneric(baseClass, udf.getClass, 0))
- }
- val (argDataTypes, outputDataType) = outputClass match {
- case Some(c) if c == classOf[Row] =>
- (defaultArgDataTypes, defaultOutputDataType)
- case Some(c) if c == classOf[RowData] =>
- val internalArgDataTypes = defaultArgDataTypes
- .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS))
- val internalOutputDataType =
- transform(defaultOutputDataType,
TypeTransformations.TO_INTERNAL_CLASS)
- (internalArgDataTypes, internalOutputDataType)
- case _ =>
- throw new ValidationException(
- s"Could not determine a type inference for lookup function
'$functionName'. " +
- s"Lookup functions support regular type inference. However,
for convenience, the " +
- s"output class can simply be a ${classOf[Row].getSimpleName}
or " +
- s"${classOf[RowData].getSimpleName} class in which case the
input and output " +
- s"types are derived from the table's schema with default
conversion.",
- e)
+ } else {
+ // for convenience, we assume internal or default external data
structures
+ // of expected logical types
+ val defaultArgDataTypes = callContext.getArgumentDataTypes.asScala
+ val defaultOutputDataType = callContext.getOutputDataType.get()
+
+ val outputClass =
+ if (udf.isInstanceOf[LookupFunction] ||
udf.isInstanceOf[AsyncLookupFunction]) {
+ Some(classOf[RowData])
+ } else {
+ toScala(extractSimpleGeneric(baseClass, udf.getClass, 0))
}
+ val (argDataTypes, outputDataType) = outputClass match {
+ case Some(c) if c == classOf[Row] =>
+ (defaultArgDataTypes, defaultOutputDataType)
+ case Some(c) if c == classOf[RowData] =>
+ val internalArgDataTypes = defaultArgDataTypes
+ .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS))
+ val internalOutputDataType =
+ transform(defaultOutputDataType,
TypeTransformations.TO_INTERNAL_CLASS)
+ (internalArgDataTypes, internalOutputDataType)
+ case _ =>
+ throw new ValidationException(
+ s"Could not determine a type inference for lookup function
'$functionName'. " +
+ s"Lookup functions support regular type inference. However, for
convenience, the " +
+ s"output class can simply be a ${classOf[Row].getSimpleName} or
" +
+ s"${classOf[RowData].getSimpleName} class in which case the
input and output " +
+ s"types are derived from the table's schema with default
conversion.",
+ e)
+ }
- verifyFunctionAwareImplementation(argDataTypes, outputDataType, udf,
functionName)
+ verifyFunctionAwareImplementation(argDataTypes, outputDataType, udf,
functionName)
- TypeInference
- .newBuilder()
- .typedArguments(argDataTypes.asJava)
- .outputTypeStrategy(TypeStrategies.explicit(outputDataType))
- .build()
+ TypeInference
+ .newBuilder()
+ .typedArguments(argDataTypes.asJava)
+ .outputTypeStrategy(TypeStrategies.explicit(outputDataType))
+ .build()
}
}
- /**
- * Generates collector for temporal join ([[Collector]])
- *
- * Differs from CommonCorrelate.generateCollector which has no real
condition because of
- * FLINK-7865, here we should deal with outer join type when real conditions
filtered result.
- */
def generateCollector(
ctx: CodeGeneratorContext,
inputRowType: RowType,
@@ -293,122 +242,14 @@ object LookupJoinCodeGenerator {
condition: Option[RexNode],
pojoFieldMapping: Option[Array[Int]],
retainHeader: Boolean = true):
GeneratedCollector[ListenableCollector[RowData]] = {
-
- val inputTerm = DEFAULT_INPUT1_TERM
- val rightInputTerm = DEFAULT_INPUT2_TERM
-
- val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
- .bindInput(rightRowType, inputTerm = rightInputTerm, inputFieldMapping =
pojoFieldMapping)
-
- val rightResultExpr =
- exprGenerator.generateConverterResultExpression(rightRowType,
classOf[GenericRowData])
-
- val joinedRowTerm = CodeGenUtils.newName(ctx, "joinedRow")
- ctx.addReusableOutputRecord(resultRowType, classOf[JoinedRowData],
joinedRowTerm)
-
- val header = if (retainHeader) {
- s"$joinedRowTerm.setRowKind($inputTerm.getRowKind());"
- } else {
- ""
- }
-
- val body =
- s"""
- |${rightResultExpr.code}
- |$joinedRowTerm.replace($inputTerm, ${rightResultExpr.resultTerm});
- |$header
- |outputResult($joinedRowTerm);
- """.stripMargin
-
- val collectorCode = if (condition.isEmpty) {
- body
- } else {
-
- val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
- .bindInput(inputRowType, inputTerm)
- .bindSecondInput(rightRowType, rightInputTerm, pojoFieldMapping)
- val filterCondition = filterGenerator.generateExpression(condition.get)
-
- s"""
- |${filterCondition.code}
- |if (${filterCondition.resultTerm}) {
- | $body
- |}
- |""".stripMargin
- }
-
- generateTableFunctionCollectorForJoinTable(
+ FunctionCallCodeGenerator.generateCollector(
ctx,
- "JoinTableFuncCollector",
- collectorCode,
inputRowType,
rightRowType,
- inputTerm = inputTerm,
- collectedTerm = rightInputTerm)
- }
-
- /**
- * The only differences against
CollectorCodeGenerator.generateTableFunctionCollector is
- * "super.collect" call is binding with collect join row in "body" code
- */
- private def generateTableFunctionCollectorForJoinTable(
- ctx: CodeGeneratorContext,
- name: String,
- bodyCode: String,
- inputType: RowType,
- collectedType: RowType,
- inputTerm: String = DEFAULT_INPUT1_TERM,
- collectedTerm: String = DEFAULT_INPUT2_TERM)
- : GeneratedCollector[ListenableCollector[RowData]] = {
-
- val funcName = newName(ctx, name)
- val input1TypeClass = boxedTypeTermForType(inputType)
- val input2TypeClass = boxedTypeTermForType(collectedType)
-
- val funcCode =
- s"""
- public class $funcName extends
${classOf[ListenableCollector[_]].getCanonicalName} {
-
- ${ctx.reuseMemberCode()}
-
- public $funcName(Object[] references) throws Exception {
- ${ctx.reuseInitCode()}
- }
-
- @Override
- public void open(${className[OpenContext]} openContext) throws
Exception {
- ${ctx.reuseOpenCode()}
- }
-
- @Override
- public void collect(Object record) throws Exception {
- $input1TypeClass $inputTerm = ($input1TypeClass) getInput();
- $input2TypeClass $collectedTerm = ($input2TypeClass) record;
-
- // callback only when collectListener exists, equivalent to:
- // getCollectListener().ifPresent(
- // listener -> ((CollectListener) listener).onCollect(record));
- // TODO we should update code splitter's grammar file to accept
lambda expressions.
-
- if (getCollectListener().isPresent()) {
- ((${classOf[CollectListener[_]].getCanonicalName})
getCollectListener().get())
- .onCollect(record);
- }
-
- ${ctx.reuseLocalVariableCode()}
- ${ctx.reuseInputUnboxingCode()}
- ${ctx.reusePerRecordCode()}
- $bodyCode
- }
-
- @Override
- public void close() throws Exception {
- ${ctx.reuseCloseCode()}
- }
- }
- """.stripMargin
-
- new GeneratedCollector(funcName, funcCode, ctx.references.toArray,
ctx.tableConfig)
+ resultRowType,
+ condition,
+ pojoFieldMapping,
+ retainHeader)
}
/**
diff --git
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala
new file mode 100644
index 00000000000..68d322f8275
--- /dev/null
+++
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.codegen
+
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.configuration.ReadableConfig
+import org.apache.flink.streaming.api.functions.async.AsyncFunction
+import org.apache.flink.table.catalog.DataTypeFactory
+import org.apache.flink.table.data.RowData
+import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction,
UserDefinedFunction}
+import
org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType
+import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil
+import org.apache.flink.table.planner.functions.inference.FunctionCallContext
+import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam
+import org.apache.flink.table.runtime.collector.ListenableCollector
+import org.apache.flink.table.runtime.generated.{GeneratedCollector,
GeneratedFunction}
+import org.apache.flink.table.types.inference.{TypeInference, TypeStrategies,
TypeTransformations}
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.table.types.utils.DataTypeUtils.transform
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+object MLPredictCodeGenerator {
+
+ /** Generates a predict function ([[TableFunction]]) */
+ def generateSyncPredictFunction(
+ tableConfig: ReadableConfig,
+ classLoader: ClassLoader,
+ dataTypeFactory: DataTypeFactory,
+ inputType: LogicalType,
+ predictFunctionOutputType: LogicalType,
+ collectorOutputType: LogicalType,
+ features: util.List[FunctionParam],
+ syncPredictFunction: TableFunction[_],
+ functionName: String,
+ fieldCopy: Boolean
+ ): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
+ FunctionCallCodeGenerator
+ .generateSyncFunctionCall(
+ tableConfig,
+ classLoader,
+ dataTypeFactory,
+ inputType,
+ predictFunctionOutputType,
+ collectorOutputType,
+ features,
+ syncPredictFunction,
+ generateCallWithDataType(functionName, predictFunctionOutputType),
+ functionName,
+ "PredictFunction",
+ fieldCopy
+ )
+ .tableFunc
+ }
+
+ /** Generates an async predict function ([[AsyncTableFunction]]) */
+ def generateAsyncPredictFunction(
+ tableConfig: ReadableConfig,
+ classLoader: ClassLoader,
+ dataTypeFactory: DataTypeFactory,
+ inputType: LogicalType,
+ predictFunctionOutputType: LogicalType,
+ collectorOutputType: LogicalType,
+ features: util.List[FunctionParam],
+ asyncPredictFunction: AsyncTableFunction[_],
+ functionName: String):
GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
+ FunctionCallCodeGenerator.generateAsyncFunctionCall(
+ tableConfig,
+ classLoader,
+ dataTypeFactory,
+ inputType,
+ predictFunctionOutputType,
+ collectorOutputType,
+ features,
+ asyncPredictFunction,
+ generateCallWithDataType(functionName, predictFunctionOutputType),
+ functionName,
+ "AsyncPredictFunction"
+ )
+ }
+
+ /** Generates a collector that joins the input row with the predicted results. */
+ def generateCollector(
+ ctx: CodeGeneratorContext,
+ inputRowType: RowType,
+ predictFunctionOutputType: RowType,
+ collectorOutputType: RowType
+ ): GeneratedCollector[ListenableCollector[RowData]] = {
+ FunctionCallCodeGenerator.generateCollector(
+ ctx,
+ inputRowType,
+ predictFunctionOutputType,
+ collectorOutputType,
+ Option.empty,
+ Option.empty
+ )
+ }
+
+ private def generateCallWithDataType(
+ functionName: String,
+ modelOutputType: LogicalType
+ ) = (
+ ctx: CodeGeneratorContext,
+ callContext: FunctionCallContext,
+ udf: UserDefinedFunction,
+ operands: Seq[GeneratedExpression]) => {
+ val inference = TypeInference
+ .newBuilder()
+ .typedArguments(
+ callContext.getArgumentDataTypes.asScala
+ .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS))
+ .asJava)
+ .outputTypeStrategy(TypeStrategies.explicit(
+ transform(callContext.getOutputDataType.get(),
TypeTransformations.TO_INTERNAL_CLASS)))
+ .build()
+ BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType(
+ ctx,
+ operands,
+ modelOutputType,
+ inference,
+ callContext,
+ udf,
+ functionName,
+ skipIfArgsNull = false
+ )
+ }
+}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java
similarity index 62%
copy from
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
copy to
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java
index 456f07bcd96..10e4defcfb4 100644
---
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java
@@ -16,30 +16,29 @@
* limitations under the License.
*/
-package org.apache.flink.table.runtime.operators.calc.async;
+package org.apache.flink.table.runtime.operators;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
-import org.apache.flink.streaming.api.functions.async.ResultFuture;
import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.AsyncLookupFunction;
+import org.apache.flink.table.functions.AsyncPredictFunction;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
/**
- * Async function runner for {@link
org.apache.flink.table.functions.AsyncScalarFunction}, which
- * takes the generated function, instantiates it, and then calls its lifecycle
methods.
+ * Base function runner for specialized table functions, e.g. {@link
AsyncLookupFunction} or {@link
+ * AsyncPredictFunction}.
*/
-public class AsyncFunctionRunner extends RichAsyncFunction<RowData, RowData> {
+public abstract class AbstractAsyncFunctionRunner<T> extends
RichAsyncFunction<RowData, RowData> {
- private static final long serialVersionUID = -7198305381139008806L;
+ protected final GeneratedFunction<AsyncFunction<RowData, T>>
generatedFetcher;
- private final GeneratedFunction<AsyncFunction<RowData, RowData>>
generatedFetcher;
+ protected transient AsyncFunction<RowData, T> fetcher;
- private transient AsyncFunction<RowData, RowData> fetcher;
-
- public AsyncFunctionRunner(
- GeneratedFunction<AsyncFunction<RowData, RowData>>
generatedFetcher) {
+ public AbstractAsyncFunctionRunner(
+ GeneratedFunction<AsyncFunction<RowData, T>> generatedFetcher) {
this.generatedFetcher = generatedFetcher;
}
@@ -51,18 +50,11 @@ public class AsyncFunctionRunner extends
RichAsyncFunction<RowData, RowData> {
FunctionUtils.openFunction(fetcher, openContext);
}
- @Override
- public void asyncInvoke(RowData input, ResultFuture<RowData> resultFuture)
{
- try {
- fetcher.asyncInvoke(input, resultFuture);
- } catch (Throwable t) {
- resultFuture.completeExceptionally(t);
- }
- }
-
@Override
public void close() throws Exception {
super.close();
- FunctionUtils.closeFunction(fetcher);
+ if (fetcher != null) {
+ FunctionUtils.closeFunction(fetcher);
+ }
}
}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java
similarity index 53%
copy from
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
copy to
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java
index 456f07bcd96..126aeba2204 100644
---
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java
@@ -16,53 +16,45 @@
* limitations under the License.
*/
-package org.apache.flink.table.runtime.operators.calc.async;
+package org.apache.flink.table.runtime.operators;
+import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
-import org.apache.flink.streaming.api.functions.async.AsyncFunction;
-import org.apache.flink.streaming.api.functions.async.ResultFuture;
-import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.LookupFunction;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
/**
- * Async function runner for {@link
org.apache.flink.table.functions.AsyncScalarFunction}, which
- * takes the generated function, instantiates it, and then calls its lifecycle
methods.
+ * Base function runner for specialized table functions, e.g. {@link
LookupFunction} or {@link
+ * org.apache.flink.table.functions.PredictFunction}.
*/
-public class AsyncFunctionRunner extends RichAsyncFunction<RowData, RowData> {
+public abstract class AbstractFunctionRunner extends ProcessFunction<RowData,
RowData> {
- private static final long serialVersionUID = -7198305381139008806L;
+ private final GeneratedFunction<FlatMapFunction<RowData, RowData>>
generatedFetcher;
- private final GeneratedFunction<AsyncFunction<RowData, RowData>>
generatedFetcher;
+ protected transient FlatMapFunction<RowData, RowData> fetcher;
- private transient AsyncFunction<RowData, RowData> fetcher;
-
- public AsyncFunctionRunner(
- GeneratedFunction<AsyncFunction<RowData, RowData>>
generatedFetcher) {
+ public AbstractFunctionRunner(
+ GeneratedFunction<FlatMapFunction<RowData, RowData>>
generatedFetcher) {
this.generatedFetcher = generatedFetcher;
}
@Override
public void open(OpenContext openContext) throws Exception {
super.open(openContext);
- fetcher =
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
+ this.fetcher =
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
+
FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
FunctionUtils.openFunction(fetcher, openContext);
}
- @Override
- public void asyncInvoke(RowData input, ResultFuture<RowData> resultFuture)
{
- try {
- fetcher.asyncInvoke(input, resultFuture);
- } catch (Throwable t) {
- resultFuture.completeExceptionally(t);
- }
- }
-
@Override
public void close() throws Exception {
+ if (fetcher != null) {
+ FunctionUtils.closeFunction(fetcher);
+ }
super.close();
- FunctionUtils.closeFunction(fetcher);
}
}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
index 456f07bcd96..7cc3a7273f7 100644
---
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
@@ -18,37 +18,23 @@
package org.apache.flink.table.runtime.operators.calc.async;
-import org.apache.flink.api.common.functions.OpenContext;
-import org.apache.flink.api.common.functions.util.FunctionUtils;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
import org.apache.flink.streaming.api.functions.async.ResultFuture;
-import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner;
/**
* Async function runner for {@link
org.apache.flink.table.functions.AsyncScalarFunction}, which
* takes the generated function, instantiates it, and then calls its lifecycle
methods.
*/
-public class AsyncFunctionRunner extends RichAsyncFunction<RowData, RowData> {
+public class AsyncFunctionRunner extends AbstractAsyncFunctionRunner<RowData> {
private static final long serialVersionUID = -7198305381139008806L;
- private final GeneratedFunction<AsyncFunction<RowData, RowData>>
generatedFetcher;
-
- private transient AsyncFunction<RowData, RowData> fetcher;
-
public AsyncFunctionRunner(
GeneratedFunction<AsyncFunction<RowData, RowData>>
generatedFetcher) {
- this.generatedFetcher = generatedFetcher;
- }
-
- @Override
- public void open(OpenContext openContext) throws Exception {
- super.open(openContext);
- fetcher =
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
- FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
- FunctionUtils.openFunction(fetcher, openContext);
+ super(generatedFetcher);
}
@Override
@@ -59,10 +45,4 @@ public class AsyncFunctionRunner extends
RichAsyncFunction<RowData, RowData> {
resultFuture.completeExceptionally(t);
}
}
-
- @Override
- public void close() throws Exception {
- super.close();
- FunctionUtils.closeFunction(fetcher);
- }
}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
index 310f38d5489..6b34c99f4a7 100644
---
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
@@ -26,7 +26,6 @@ import org.apache.flink.configuration.Configuration;
import org.apache.flink.streaming.api.functions.async.AsyncFunction;
import org.apache.flink.streaming.api.functions.async.CollectionSupplier;
import org.apache.flink.streaming.api.functions.async.ResultFuture;
-import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.conversion.DataStructureConverter;
@@ -35,6 +34,7 @@ import
org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
import org.apache.flink.table.runtime.generated.FilterCondition;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
+import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner;
import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
import java.util.ArrayList;
@@ -45,10 +45,9 @@ import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.BlockingQueue;
/** The async join runner to lookup the dimension table. */
-public class AsyncLookupJoinRunner extends RichAsyncFunction<RowData, RowData>
{
+public class AsyncLookupJoinRunner extends AbstractAsyncFunctionRunner<Object>
{
private static final long serialVersionUID = -6664660022391632480L;
- private final GeneratedFunction<AsyncFunction<RowData, Object>>
generatedFetcher;
private final DataStructureConverter<RowData, Object> fetcherConverter;
private final GeneratedResultFuture<TableFunctionResultFuture<RowData>>
generatedResultFuture;
private final GeneratedFunction<FilterCondition>
generatedPreFilterCondition;
@@ -56,8 +55,6 @@ public class AsyncLookupJoinRunner extends
RichAsyncFunction<RowData, RowData> {
private final boolean isLeftOuterJoin;
private final int asyncBufferCapacity;
- private transient AsyncFunction<RowData, Object> fetcher;
-
protected final RowDataSerializer rightRowSerializer;
/**
@@ -83,7 +80,7 @@ public class AsyncLookupJoinRunner extends
RichAsyncFunction<RowData, RowData> {
RowDataSerializer rightRowSerializer,
boolean isLeftOuterJoin,
int asyncBufferCapacity) {
- this.generatedFetcher = generatedFetcher;
+ super(generatedFetcher);
this.fetcherConverter = fetcherConverter;
this.generatedResultFuture = generatedResultFuture;
this.generatedPreFilterCondition = generatedPreFilterCondition;
@@ -96,11 +93,9 @@ public class AsyncLookupJoinRunner extends
RichAsyncFunction<RowData, RowData> {
public void open(OpenContext openContext) throws Exception {
super.open(openContext);
ClassLoader cl = getRuntimeContext().getUserCodeClassLoader();
- this.fetcher = generatedFetcher.newInstance(cl);
this.preFilterCondition = generatedPreFilterCondition.newInstance(cl);
FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
FunctionUtils.setFunctionRuntimeContext(preFilterCondition,
getRuntimeContext());
- FunctionUtils.openFunction(fetcher, openContext);
FunctionUtils.openFunction(preFilterCondition, openContext);
// try to compile the generated ResultFuture, fail fast if the code is
corrupt.
@@ -152,9 +147,6 @@ public class AsyncLookupJoinRunner extends
RichAsyncFunction<RowData, RowData> {
@Override
public void close() throws Exception {
super.close();
- if (fetcher != null) {
- FunctionUtils.closeFunction(fetcher);
- }
if (preFilterCondition != null) {
FunctionUtils.closeFunction(preFilterCondition);
}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
index c91a9feb7bf..4ecb0b53f59 100644
---
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
@@ -21,7 +21,6 @@ package org.apache.flink.table.runtime.operators.join.lookup;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.OpenContext;
import org.apache.flink.api.common.functions.util.FunctionUtils;
-import org.apache.flink.streaming.api.functions.ProcessFunction;
import org.apache.flink.table.data.GenericRowData;
import org.apache.flink.table.data.RowData;
import org.apache.flink.table.data.utils.JoinedRowData;
@@ -29,20 +28,19 @@ import
org.apache.flink.table.runtime.collector.ListenableCollector;
import org.apache.flink.table.runtime.generated.FilterCondition;
import org.apache.flink.table.runtime.generated.GeneratedCollector;
import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractFunctionRunner;
import org.apache.flink.util.Collector;
/** The join runner to lookup the dimension table. */
-public class LookupJoinRunner extends ProcessFunction<RowData, RowData> {
+public class LookupJoinRunner extends AbstractFunctionRunner {
private static final long serialVersionUID = -4521543015709964733L;
- private final GeneratedFunction<FlatMapFunction<RowData, RowData>>
generatedFetcher;
private final GeneratedCollector<ListenableCollector<RowData>>
generatedCollector;
private final GeneratedFunction<FilterCondition>
generatedPreFilterCondition;
protected final boolean isLeftOuterJoin;
protected final int tableFieldsCount;
- private transient FlatMapFunction<RowData, RowData> fetcher;
protected transient ListenableCollector<RowData> collector;
protected transient JoinedRowData outRow;
protected transient FilterCondition preFilterCondition;
@@ -54,7 +52,7 @@ public class LookupJoinRunner extends
ProcessFunction<RowData, RowData> {
GeneratedFunction<FilterCondition> generatedPreFilterCondition,
boolean isLeftOuterJoin,
int tableFieldsCount) {
- this.generatedFetcher = generatedFetcher;
+ super(generatedFetcher);
this.generatedCollector = generatedCollector;
this.generatedPreFilterCondition = generatedPreFilterCondition;
this.isLeftOuterJoin = isLeftOuterJoin;
@@ -64,17 +62,14 @@ public class LookupJoinRunner extends
ProcessFunction<RowData, RowData> {
@Override
public void open(OpenContext openContext) throws Exception {
super.open(openContext);
- this.fetcher =
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
this.collector =
generatedCollector.newInstance(getRuntimeContext().getUserCodeClassLoader());
this.preFilterCondition =
generatedPreFilterCondition.newInstance(
getRuntimeContext().getUserCodeClassLoader());
- FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
FunctionUtils.setFunctionRuntimeContext(collector,
getRuntimeContext());
FunctionUtils.setFunctionRuntimeContext(preFilterCondition,
getRuntimeContext());
- FunctionUtils.openFunction(fetcher, openContext);
FunctionUtils.openFunction(collector, openContext);
FunctionUtils.openFunction(preFilterCondition, openContext);
@@ -124,9 +119,6 @@ public class LookupJoinRunner extends
ProcessFunction<RowData, RowData> {
@Override
public void close() throws Exception {
- if (fetcher != null) {
- FunctionUtils.closeFunction(fetcher);
- }
if (collector != null) {
FunctionUtils.closeFunction(collector);
}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java
new file mode 100644
index 00000000000..aa77eedaeaf
--- /dev/null
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.ml;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.async.AsyncFunction;
+import org.apache.flink.streaming.api.functions.async.CollectionSupplier;
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+
+/**
+ * Async function runner for {@link AsyncPredictFunction}, which takes the
generated function,
+ * instantiates it, and then calls its lifecycle methods.
+ */
+public class AsyncMLPredictRunner extends AbstractAsyncFunctionRunner<RowData>
{
+
+ private final int asyncBufferCapacity;
+
+ /**
+ * Buffers {@link ResultFuture} to avoid newInstance cost when processing
elements every time.
+ * We use {@link BlockingQueue} to make sure the head {@link
ResultFuture}s are available.
+ */
+ private transient BlockingQueue<JoinedRowResultFuture> resultFutureBuffer;
+
+ public AsyncMLPredictRunner(
+ GeneratedFunction<AsyncFunction<RowData, RowData>>
generatedFetcher,
+ int asyncBufferCapacity) {
+ super(generatedFetcher);
+ this.asyncBufferCapacity = asyncBufferCapacity;
+ }
+
+ @Override
+ public void open(OpenContext openContext) throws Exception {
+ super.open(openContext);
+ this.resultFutureBuffer = new ArrayBlockingQueue<>(asyncBufferCapacity
+ 1);
+ for (int i = 0; i < asyncBufferCapacity + 1; i++) {
+ JoinedRowResultFuture rf = new
JoinedRowResultFuture(resultFutureBuffer);
+ // add will throw exception immediately if the queue is full which
should never happen
+ resultFutureBuffer.add(rf);
+ }
+ registerMetric(getRuntimeContext().getMetricGroup());
+ }
+
+ @Override
+ public void asyncInvoke(RowData input, ResultFuture<RowData> resultFuture)
throws Exception {
+ try {
+ JoinedRowResultFuture buffer = resultFutureBuffer.take();
+ buffer.reset(input, resultFuture);
+ fetcher.asyncInvoke(input, buffer);
+ } catch (Throwable t) {
+ resultFuture.completeExceptionally(t);
+ }
+ }
+
+ private void registerMetric(MetricGroup metricGroup) {
+ metricGroup.gauge(
+ "ai_queue_length", () -> asyncBufferCapacity + 1 -
resultFutureBuffer.size());
+ metricGroup.gauge("ai_queue_capacity", () -> asyncBufferCapacity);
+ metricGroup.gauge(
+ "ai_queue_usage_ratio",
+ () ->
+ 1.0
+ * (asyncBufferCapacity + 1 -
resultFutureBuffer.size())
+ / asyncBufferCapacity);
+ }
+
+ private static final class JoinedRowResultFuture implements
ResultFuture<RowData> {
+
+ private final BlockingQueue<JoinedRowResultFuture> resultFutureBuffer;
+
+ private ResultFuture<RowData> realOutput;
+ private RowData leftRow;
+
+ public JoinedRowResultFuture(BlockingQueue<JoinedRowResultFuture>
resultFutureBuffer) {
+ this.resultFutureBuffer = resultFutureBuffer;
+ }
+
+ public void reset(RowData row, ResultFuture<RowData> realOutput) {
+ this.realOutput = realOutput;
+ this.leftRow = row;
+ }
+
+ @Override
+ public void complete(Collection<RowData> result) {
+ List<RowData> outRows = new ArrayList<>();
+ for (RowData rightRow : result) {
+ RowData outRow = new JoinedRowData(leftRow.getRowKind(),
leftRow, rightRow);
+ outRows.add(outRow);
+ }
+ realOutput.complete(outRows);
+
+ try {
+ // put this collector to the queue to avoid this collector is
used
+ // again before outRows in the collector is not consumed.
+ resultFutureBuffer.put(this);
+ } catch (InterruptedException e) {
+ completeExceptionally(e);
+ }
+ }
+
+ @Override
+ public void completeExceptionally(Throwable error) {
+ realOutput.completeExceptionally(error);
+ }
+
+ @Override
+ public void complete(CollectionSupplier<RowData> supplier) {
+ throw new UnsupportedOperationException();
+ }
+ }
+}
diff --git
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java
new file mode 100644
index 00000000000..a8a486911f8
--- /dev/null
+++
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.ml;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.PredictFunction;
+import org.apache.flink.table.runtime.collector.ListenableCollector;
+import org.apache.flink.table.runtime.generated.GeneratedCollector;
+import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractFunctionRunner;
+import org.apache.flink.util.Collector;
+
+/**
+ * Function runner for {@link PredictFunction}, which takes the generated function, instantiates it,
+ * and then calls its lifecycle methods.
+ *
+ * <p>The generated fetcher emits its results into a generated {@link ListenableCollector}. For
+ * every element, that collector is bound to the current input row and the downstream collector.
+ */
+public class MLPredictRunner extends AbstractFunctionRunner {
+
+    private final GeneratedCollector<ListenableCollector<RowData>> generatedCollector;
+
+    protected transient ListenableCollector<RowData> collector;
+
+    public MLPredictRunner(
+            GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher,
+            GeneratedCollector<ListenableCollector<RowData>> generatedCollector) {
+        super(generatedFetcher);
+        this.generatedCollector = generatedCollector;
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        this.collector =
+                generatedCollector.newInstance(getRuntimeContext().getUserCodeClassLoader());
+        FunctionUtils.setFunctionRuntimeContext(collector, getRuntimeContext());
+        FunctionUtils.openFunction(collector, openContext);
+    }
+
+    @Override
+    public void processElement(
+            RowData in, ProcessFunction<RowData, RowData>.Context ctx, Collector<RowData> out)
+            throws Exception {
+        prepareCollector(in, out);
+        fetcher.flatMap(in, collector);
+    }
+
+    /** Binds the generated collector to the current input row and the downstream collector. */
+    public void prepareCollector(RowData in, Collector<RowData> out) {
+        collector.setCollector(out);
+        collector.setInput(in);
+        collector.reset();
+    }
+
+    @Override
+    public void close() throws Exception {
+        try {
+            // The collector is opened in open(), so its lifecycle must be closed here as well.
+            if (collector != null) {
+                FunctionUtils.closeFunction(collector);
+            }
+        } finally {
+            super.close();
+        }
+    }
+}