This is an automated email from the ASF dual-hosted git repository.

shengkai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b6e69e6f7e [FLINK-38435][table] Refactor codegen and runner for MLPredict and LookupJoin (#27041)
8b6e69e6f7e is described below

commit 8b6e69e6f7ecb105dd7478a3613c2a656de4e9e9
Author: Shengkai <[email protected]>
AuthorDate: Thu Oct 16 09:40:08 2025 +0800

    [FLINK-38435][table] Refactor codegen and runner for MLPredict and LookupJoin (#27041)
---
 ...upCallContext.java => FunctionCallContext.java} |  24 +-
 .../nodes/exec/common/CommonExecLookupJoin.java    |   3 +-
 .../nodes/exec/stream/StreamExecDeltaJoin.java     |   3 +-
 .../stream/StreamExecMLPredictTableFunction.java   |  69 +---
 .../codegen/FunctionCallCodeGenerator.scala        | 344 ++++++++++++++++++
 .../planner/codegen/LookupJoinCodeGenerator.scala  | 403 +++++++--------------
 .../planner/codegen/MLPredictCodeGenerator.scala   | 144 ++++++++
 ...unner.java => AbstractAsyncFunctionRunner.java} |  34 +-
 ...tionRunner.java => AbstractFunctionRunner.java} |  40 +-
 .../operators/calc/async/AsyncFunctionRunner.java  |  26 +-
 .../join/lookup/AsyncLookupJoinRunner.java         |  14 +-
 .../operators/join/lookup/LookupJoinRunner.java    |  14 +-
 .../runtime/operators/ml/AsyncMLPredictRunner.java | 138 +++++++
 .../runtime/operators/ml/MLPredictRunner.java      |  73 ++++
 14 files changed, 896 insertions(+), 433 deletions(-)

diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java
similarity index 87%
rename from flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
rename to flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java
index ce655466c0f..562012f225b 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/LookupCallContext.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/inference/FunctionCallContext.java
@@ -21,6 +21,7 @@ import org.apache.flink.annotation.Internal;
 import org.apache.flink.table.catalog.DataTypeFactory;
 import org.apache.flink.table.connector.source.LookupTableSource;
 import org.apache.flink.table.functions.UserDefinedFunction;
+import org.apache.flink.table.ml.PredictRuntimeProvider;
 import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.Constant;
 import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam;
 import org.apache.flink.table.types.DataType;
@@ -38,24 +39,27 @@ import static org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FieldRe
 import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.getFieldTypes;
 import static org.apache.flink.table.types.utils.TypeConversions.fromLogicalToDataType;
 
-/** The {@link CallContext} of a {@link LookupTableSource} runtime function. */
+/**
+ * The {@link CallContext} of {@link LookupTableSource}, {@link PredictRuntimeProvider} runtime
+ * function.
+ */
 @Internal
-public class LookupCallContext extends AbstractSqlCallContext {
+public class FunctionCallContext extends AbstractSqlCallContext {
 
-    private final List<FunctionParam> lookupKeys;
+    private final List<FunctionParam> params;
 
     private final List<DataType> argumentDataTypes;
 
     private final DataType outputDataType;
 
-    public LookupCallContext(
+    public FunctionCallContext(
             DataTypeFactory dataTypeFactory,
             UserDefinedFunction function,
             LogicalType inputType,
-            List<FunctionParam> lookupKeys,
-            LogicalType lookupType) {
+            List<FunctionParam> params,
+            LogicalType outputDataType) {
         super(dataTypeFactory, function, generateInlineFunctionName(function), false);
-        this.lookupKeys = lookupKeys;
+        this.params = params;
         this.argumentDataTypes =
                 new AbstractList<>() {
                     @Override
@@ -74,10 +78,10 @@ public class LookupCallContext extends AbstractSqlCallContext {
 
                     @Override
                     public int size() {
-                        return lookupKeys.size();
+                        return params.size();
                     }
                 };
-        this.outputDataType = fromLogicalToDataType(lookupType);
+        this.outputDataType = fromLogicalToDataType(outputDataType);
     }
 
     @Override
@@ -118,6 +122,6 @@ public class LookupCallContext extends AbstractSqlCallContext {
     // --------------------------------------------------------------------------------------------
 
     private FunctionParam getKey(int pos) {
-        return lookupKeys.get(pos);
+        return params.get(pos);
     }
 }
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
index bbf4486c5ed..05b07458a10 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
@@ -46,6 +46,7 @@ import org.apache.flink.table.legacy.sources.TableSource;
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
 import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
 import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
+import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
 import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
 import org.apache.flink.table.planner.delegation.PlannerBase;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -459,7 +460,7 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> {
                         .mapToObj(allLookupKeys::get)
                         .collect(Collectors.toList());
 
-        LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
+        FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
                 generatedFuncWithType =
                         LookupJoinCodeGenerator.generateAsyncLookupFunction(
                                 config,
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
index 39a48443e0d..6d368618c91 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
@@ -31,6 +31,7 @@ import org.apache.flink.table.data.conversion.DataStructureConverters;
 import org.apache.flink.table.functions.AsyncTableFunction;
 import org.apache.flink.table.functions.UserDefinedFunctionHelper;
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
+import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
 import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
 import org.apache.flink.table.planner.delegation.PlannerBase;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecEdge;
@@ -353,7 +354,7 @@ public class StreamExecDeltaJoin extends ExecNodeBase<RowData>
                         .mapToObj(lookupKeys::get)
                         .collect(Collectors.toList());
 
-        LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
+        FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
                 lookupSideGeneratedFuncWithType =
                         LookupJoinCodeGenerator.generateAsyncLookupFunction(
                                 config,
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
index 0ff85e23031..6d1fe8a8cee 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecMLPredictTableFunction.java
@@ -31,8 +31,6 @@ import org.apache.flink.streaming.api.operators.async.AsyncWaitOperatorFactory;
 import org.apache.flink.table.api.TableException;
 import org.apache.flink.table.catalog.DataTypeFactory;
 import org.apache.flink.table.data.RowData;
-import org.apache.flink.table.data.conversion.DataStructureConverter;
-import org.apache.flink.table.data.conversion.DataStructureConverters;
 import org.apache.flink.table.functions.AsyncPredictFunction;
 import org.apache.flink.table.functions.PredictFunction;
 import org.apache.flink.table.functions.UserDefinedFunction;
@@ -41,8 +39,8 @@ import org.apache.flink.table.ml.ModelProvider;
 import org.apache.flink.table.ml.PredictRuntimeProvider;
 import org.apache.flink.table.planner.calcite.FlinkContext;
 import org.apache.flink.table.planner.codegen.CodeGeneratorContext;
-import org.apache.flink.table.planner.codegen.FilterCodeGenerator;
-import org.apache.flink.table.planner.codegen.LookupJoinCodeGenerator;
+import org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator;
+import org.apache.flink.table.planner.codegen.MLPredictCodeGenerator;
 import org.apache.flink.table.planner.delegation.PlannerBase;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecNode;
 import org.apache.flink.table.planner.plan.nodes.exec.ExecNodeBase;
@@ -55,18 +53,15 @@ import org.apache.flink.table.planner.plan.nodes.exec.spec.MLPredictSpec;
 import org.apache.flink.table.planner.plan.nodes.exec.spec.ModelSpec;
 import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
 import org.apache.flink.table.planner.plan.utils.FunctionCallUtil;
-import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
 import org.apache.flink.table.runtime.collector.ListenableCollector;
-import org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
 import org.apache.flink.table.runtime.functions.ml.ModelPredictRuntimeProviderContext;
 import org.apache.flink.table.runtime.generated.GeneratedCollector;
 import org.apache.flink.table.runtime.generated.GeneratedFunction;
-import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
-import org.apache.flink.table.runtime.operators.join.lookup.AsyncLookupJoinRunner;
-import org.apache.flink.table.runtime.operators.join.lookup.LookupJoinRunner;
-import org.apache.flink.table.runtime.typeutils.InternalSerializers;
+import org.apache.flink.table.runtime.operators.ml.AsyncMLPredictRunner;
+import org.apache.flink.table.runtime.operators.ml.MLPredictRunner;
 import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
 import org.apache.flink.table.types.logical.RowType;
+import org.apache.flink.util.Preconditions;
 
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonCreator;
 import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.annotation.JsonProperty;
@@ -75,7 +70,6 @@ import javax.annotation.Nullable;
 
 import java.util.Collections;
 import java.util.List;
-import java.util.Optional;
 
 /** Stream {@link ExecNode} for {@code ML_PREDICT}. */
 @ExecNodeMetadata(
@@ -197,7 +191,7 @@ public class StreamExecMLPredictTableFunction extends ExecNodeBase<RowData>
             RowType resultRowType,
             PredictFunction predictFunction) {
         GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher =
-                LookupJoinCodeGenerator.generateSyncLookupFunction(
+                MLPredictCodeGenerator.generateSyncPredictFunction(
                         config,
                         classLoader,
                         dataTypeFactory,
@@ -206,25 +200,15 @@ public class StreamExecMLPredictTableFunction extends ExecNodeBase<RowData>
                         resultRowType,
                         mlPredictSpec.getFeatures(),
                         predictFunction,
-                        "MLPredict",
+                        modelSpec.getContextResolvedModel().getIdentifier().asSummaryString(),
                         config.get(PipelineOptions.OBJECT_REUSE));
         GeneratedCollector<ListenableCollector<RowData>> generatedCollector =
-                LookupJoinCodeGenerator.generateCollector(
+                MLPredictCodeGenerator.generateCollector(
                         new CodeGeneratorContext(config, classLoader),
                         inputRowType,
                         modelOutputType,
-                        (RowType) getOutputType(),
-                        JavaScalaConversionUtil.toScala(Optional.empty()),
-                        JavaScalaConversionUtil.toScala(Optional.empty()),
-                        true);
-        LookupJoinRunner mlPredictRunner =
-                new LookupJoinRunner(
-                        generatedFetcher,
-                        generatedCollector,
-                        FilterCodeGenerator.generateFilterCondition(
-                                config, classLoader, null, inputRowType),
-                        false,
-                        modelOutputType.getFieldCount());
+                        (RowType) getOutputType());
+        MLPredictRunner mlPredictRunner = new MLPredictRunner(generatedFetcher, generatedCollector);
         SimpleOperatorFactory<RowData> operatorFactory =
                 SimpleOperatorFactory.of(new ProcessOperator<>(mlPredictRunner));
         return ExecNodeUtil.createOneInputTransformation(
@@ -246,9 +230,9 @@ public class StreamExecMLPredictTableFunction extends ExecNodeBase<RowData>
             RowType modelOutputType,
             RowType resultRowType,
             AsyncPredictFunction asyncPredictFunction) {
-        LookupJoinCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
+        FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType<AsyncFunction<RowData, Object>>
                 generatedFuncWithType =
-                        LookupJoinCodeGenerator.generateAsyncLookupFunction(
+                        MLPredictCodeGenerator.generateAsyncPredictFunction(
                                 config,
                                 classLoader,
                                 dataTypeFactory,
@@ -257,29 +241,14 @@ public class StreamExecMLPredictTableFunction extends ExecNodeBase<RowData>
                                 resultRowType,
                                 mlPredictSpec.getFeatures(),
                                 asyncPredictFunction,
-                                "AsyncMLPredict");
-
-        GeneratedResultFuture<TableFunctionResultFuture<RowData>> generatedResultFuture =
-                LookupJoinCodeGenerator.generateTableAsyncCollector(
-                        config,
-                        classLoader,
-                        "TableFunctionResultFuture",
-                        inputRowType,
-                        modelOutputType,
-                        JavaScalaConversionUtil.toScala(Optional.empty()));
-
-        DataStructureConverter<?, ?> fetcherConverter =
-                DataStructureConverters.getConverter(generatedFuncWithType.dataType());
+                                modelSpec
+                                        .getContextResolvedModel()
+                                        .getIdentifier()
+                                        .asSummaryString());
         AsyncFunction<RowData, RowData> asyncFunc =
-                new AsyncLookupJoinRunner(
-                        generatedFuncWithType.tableFunc(),
-                        (DataStructureConverter<RowData, Object>) fetcherConverter,
-                        generatedResultFuture,
-                        FilterCodeGenerator.generateFilterCondition(
-                                config, classLoader, null, inputRowType),
-                        InternalSerializers.create(modelOutputType),
-                        false,
-                        asyncOptions.asyncBufferCapacity);
+                new AsyncMLPredictRunner(
+                        (GeneratedFunction) generatedFuncWithType.tableFunc(),
+                        Preconditions.checkNotNull(asyncOptions).asyncBufferCapacity);
         return ExecNodeUtil.createOneInputTransformation(
                 inputTransformation,
                 createTransformationMeta(ML_PREDICT_TRANSFORMATION, config),
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala
new file mode 100644
index 00000000000..438f14c21e3
--- /dev/null
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCallCodeGenerator.scala
@@ -0,0 +1,344 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.codegen
+
+import org.apache.flink.api.common.functions.{FlatMapFunction, Function, OpenContext}
+import org.apache.flink.configuration.ReadableConfig
+import org.apache.flink.streaming.api.functions.async.AsyncFunction
+import org.apache.flink.table.catalog.DataTypeFactory
+import org.apache.flink.table.data.{GenericRowData, RowData}
+import org.apache.flink.table.data.utils.JoinedRowData
+import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction, UserDefinedFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.planner.codegen.CodeGenUtils.{boxedTypeTermForType, className, newName, DEFAULT_COLLECTOR_TERM, DEFAULT_INPUT1_TERM, DEFAULT_INPUT2_TERM}
+import org.apache.flink.table.planner.codegen.GenerateUtils.{generateInputAccess, generateLiteral}
+import org.apache.flink.table.planner.delegation.PlannerBase
+import org.apache.flink.table.planner.functions.inference.FunctionCallContext
+import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.{Constant, FieldRef, FunctionParam}
+import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
+import org.apache.flink.table.runtime.collector.ListenableCollector
+import org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
+import org.apache.flink.table.runtime.generated.{GeneratedCollector, GeneratedFunction}
+import org.apache.flink.table.types.DataType
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.util.Collector
+
+import org.apache.calcite.rex.RexNode
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+object FunctionCallCodeGenerator {
+
+  case class GeneratedTableFunctionWithDataType[F <: Function](
+      tableFunc: GeneratedFunction[F],
+      dataType: DataType)
+
+  /** Generates a sync function ([[TableFunction]]) call. */
+  def generateSyncFunctionCall(
+      tableConfig: ReadableConfig,
+      classLoader: ClassLoader,
+      dataTypeFactory: DataTypeFactory,
+      inputType: LogicalType,
+      functionOutputType: LogicalType,
+      collectorOutputType: LogicalType,
+      parameters: util.List[FunctionParam],
+      syncFunctionDefinition: TableFunction[_],
+      inferCall: (
+          CodeGeneratorContext,
+          FunctionCallContext,
+          UserDefinedFunction,
+          Seq[GeneratedExpression]) => (GeneratedExpression, DataType),
+      functionName: String,
+      generateClassName: String,
+      fieldCopy: Boolean): GeneratedTableFunctionWithDataType[FlatMapFunction[RowData, RowData]] = {
+
+    val bodyCode: GeneratedExpression => String = call => {
+      val resultCollectorTerm = call.resultTerm
+      s"""
+         |$resultCollectorTerm.setCollector($DEFAULT_COLLECTOR_TERM);
+         |${call.code}
+         |""".stripMargin
+    }
+
+    generateFunctionCall(
+      classOf[FlatMapFunction[RowData, RowData]],
+      tableConfig,
+      classLoader,
+      dataTypeFactory,
+      inputType,
+      functionOutputType,
+      collectorOutputType,
+      parameters,
+      syncFunctionDefinition,
+      inferCall,
+      functionName,
+      generateClassName,
+      fieldCopy,
+      bodyCode
+    )
+  }
+
+  /** Generates an async function ([[AsyncTableFunction]]) call. */
+  def generateAsyncFunctionCall(
+      tableConfig: ReadableConfig,
+      classLoader: ClassLoader,
+      dataTypeFactory: DataTypeFactory,
+      inputType: LogicalType,
+      functionOutputType: LogicalType,
+      collectorOutputType: LogicalType,
+      parameters: util.List[FunctionParam],
+      asyncFunctionDefinition: AsyncTableFunction[_],
+      generateCallWithDataType: (
+          CodeGeneratorContext,
+          FunctionCallContext,
+          UserDefinedFunction,
+          Seq[GeneratedExpression]) => (GeneratedExpression, DataType),
+      functionName: String,
+      generateClassName: String
+  ): GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
+    generateFunctionCall(
+      classOf[AsyncFunction[RowData, AnyRef]],
+      tableConfig,
+      classLoader,
+      dataTypeFactory,
+      inputType,
+      functionOutputType,
+      collectorOutputType,
+      parameters,
+      asyncFunctionDefinition,
+      generateCallWithDataType,
+      functionName,
+      generateClassName,
+      fieldCopy = true,
+      _.code
+    )
+  }
+
+  private def generateFunctionCall[F <: Function](
+      generatedClass: Class[F],
+      tableConfig: ReadableConfig,
+      classLoader: ClassLoader,
+      dataTypeFactory: DataTypeFactory,
+      inputType: LogicalType,
+      functionOutputType: LogicalType,
+      collectorOutputType: LogicalType,
+      parameters: util.List[FunctionParam],
+      functionDefinition: UserDefinedFunction,
+      generateCallWithDataType: (
+          CodeGeneratorContext,
+          FunctionCallContext,
+          UserDefinedFunction,
+          Seq[GeneratedExpression]) => (GeneratedExpression, DataType),
+      functionName: String,
+      generateClassName: String,
+      fieldCopy: Boolean,
+      bodyCode: GeneratedExpression => String): GeneratedTableFunctionWithDataType[F] = {
+
+    val callContext =
+      new FunctionCallContext(
+        dataTypeFactory,
+        functionDefinition,
+        inputType,
+        parameters,
+        functionOutputType)
+
+    // create the final UDF for runtime
+    val udf = UserDefinedFunctionHelper.createSpecializedFunction(
+      functionName,
+      functionDefinition,
+      callContext,
+      classOf[PlannerBase].getClassLoader,
+      tableConfig,
+      // no need to support expression evaluation at this point
+      null
+    )
+
+    val ctx = new CodeGeneratorContext(tableConfig, classLoader)
+    val operands = prepareOperands(ctx, inputType, parameters, fieldCopy)
+
+    val callWithDataType: (GeneratedExpression, DataType) =
+      generateCallWithDataType(ctx, callContext, udf, operands)
+
+    val function = FunctionCodeGenerator.generateFunction(
+      ctx,
+      generateClassName,
+      generatedClass,
+      bodyCode(callWithDataType._1),
+      collectorOutputType,
+      inputType)
+
+    GeneratedTableFunctionWithDataType(function, callWithDataType._2)
+  }
+
+  private def prepareOperands(
+      ctx: CodeGeneratorContext,
+      inputType: LogicalType,
+      parameters: util.List[FunctionParam],
+      fieldCopy: Boolean): Seq[GeneratedExpression] = {
+
+    parameters.asScala
+      .map {
+        case constantKey: Constant =>
+          val res = RexLiteralUtil.toFlinkInternalValue(constantKey.literal)
+          generateLiteral(ctx, res.f0, res.f1)
+        case fieldKey: FieldRef =>
+          generateInputAccess(
+            ctx,
+            inputType,
+            DEFAULT_INPUT1_TERM,
+            fieldKey.index,
+            nullableInput = false,
+            fieldCopy)
+        case _ =>
+          throw new CodeGenException("Invalid parameters.")
+      }
+  }
+
+  /**
+   * Generates collector for join ([[Collector]])
+   *
+   * Differs from CommonCorrelate.generateCollector which has no real condition because of
+   * FLINK-7865, here we should deal with outer join type when real conditions filtered result.
+   */
+  def generateCollector(
+      ctx: CodeGeneratorContext,
+      inputRowType: RowType,
+      rightRowType: RowType,
+      resultRowType: RowType,
+      condition: Option[RexNode],
+      pojoFieldMapping: Option[Array[Int]],
+      retainHeader: Boolean = true): GeneratedCollector[ListenableCollector[RowData]] = {
+
+    val inputTerm = DEFAULT_INPUT1_TERM
+    val rightInputTerm = DEFAULT_INPUT2_TERM
+
+    val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
+      .bindInput(rightRowType, inputTerm = rightInputTerm, inputFieldMapping = pojoFieldMapping)
+
+    val rightResultExpr =
+      exprGenerator.generateConverterResultExpression(rightRowType, classOf[GenericRowData])
+
+    val joinedRowTerm = CodeGenUtils.newName(ctx, "joinedRow")
+    ctx.addReusableOutputRecord(resultRowType, classOf[JoinedRowData], joinedRowTerm)
+
+    val header = if (retainHeader) {
+      s"$joinedRowTerm.setRowKind($inputTerm.getRowKind());"
+    } else {
+      ""
+    }
+
+    val body =
+      s"""
+         |${rightResultExpr.code}
+         |$joinedRowTerm.replace($inputTerm, ${rightResultExpr.resultTerm});
+         |$header
+         |outputResult($joinedRowTerm);
+      """.stripMargin
+
+    val collectorCode = if (condition.isEmpty) {
+      body
+    } else {
+
+      val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
+        .bindInput(inputRowType, inputTerm)
+        .bindSecondInput(rightRowType, rightInputTerm, pojoFieldMapping)
+      val filterCondition = filterGenerator.generateExpression(condition.get)
+
+      s"""
+         |${filterCondition.code}
+         |if (${filterCondition.resultTerm}) {
+         |  $body
+         |}
+         |""".stripMargin
+    }
+
+    generateTableFunctionCollectorForJoinTable(
+      ctx,
+      "JoinTableFuncCollector",
+      collectorCode,
+      inputRowType,
+      rightRowType,
+      inputTerm = inputTerm,
+      collectedTerm = rightInputTerm)
+  }
+
+  /**
+   * The only differences against CollectorCodeGenerator.generateTableFunctionCollector is
+   * "super.collect" call is binding with collect join row in "body" code
+   */
+  private def generateTableFunctionCollectorForJoinTable(
+      ctx: CodeGeneratorContext,
+      name: String,
+      bodyCode: String,
+      inputType: RowType,
+      collectedType: RowType,
+      inputTerm: String = DEFAULT_INPUT1_TERM,
+      collectedTerm: String = DEFAULT_INPUT2_TERM)
+      : GeneratedCollector[ListenableCollector[RowData]] = {
+
+    val funcName = newName(ctx, name)
+    val input1TypeClass = boxedTypeTermForType(inputType)
+    val input2TypeClass = boxedTypeTermForType(collectedType)
+
+    val funcCode =
+      s"""
+      public class $funcName extends ${classOf[ListenableCollector[_]].getCanonicalName} {
+
+        ${ctx.reuseMemberCode()}
+
+        public $funcName(Object[] references) throws Exception {
+          ${ctx.reuseInitCode()}
+        }
+
+        @Override
+        public void open(${className[OpenContext]} openContext) throws Exception {
+          ${ctx.reuseOpenCode()}
+        }
+
+        @Override
+        public void collect(Object record) throws Exception {
+          $input1TypeClass $inputTerm = ($input1TypeClass) getInput();
+          $input2TypeClass $collectedTerm = ($input2TypeClass) record;
+
+          // callback only when collectListener exists, equivalent to:
+          // getCollectListener().ifPresent(
+          //   listener -> ((CollectListener) listener).onCollect(record));
+          // TODO we should update code splitter's grammar file to accept lambda expressions.
+
+          if (getCollectListener().isPresent()) {
+             ((${classOf[CollectListener[_]].getCanonicalName}) getCollectListener().get())
+             .onCollect(record);
+          }
+
+          ${ctx.reuseLocalVariableCode()}
+          ${ctx.reuseInputUnboxingCode()}
+          ${ctx.reusePerRecordCode()}
+          $bodyCode
+        }
+
+        @Override
+        public void close() throws Exception {
+          ${ctx.reuseCloseCode()}
+        }
+      }
+    """.stripMargin
+
+    new GeneratedCollector(funcName, funcCode, ctx.references.toArray, ctx.tableConfig)
+  }
+}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
index e61783b6844..2f3ab1708b8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.flink.table.planner.codegen
 
-import org.apache.flink.api.common.functions.{FlatMapFunction, Function, OpenContext}
+import org.apache.flink.api.common.functions.{FlatMapFunction, OpenContext}
 import org.apache.flink.configuration.ReadableConfig
 import org.apache.flink.streaming.api.functions.async.AsyncFunction
 import org.apache.flink.table.api.ValidationException
@@ -25,17 +25,15 @@ import org.apache.flink.table.catalog.DataTypeFactory
 import org.apache.flink.table.connector.source.{LookupTableSource, 
ScanTableSource}
 import org.apache.flink.table.data.{GenericRowData, RowData}
 import org.apache.flink.table.data.utils.JoinedRowData
-import org.apache.flink.table.functions.{AsyncLookupFunction, 
AsyncPredictFunction, AsyncTableFunction, LookupFunction, PredictFunction, 
TableFunction, UserDefinedFunction, UserDefinedFunctionHelper}
+import org.apache.flink.table.functions.{AsyncLookupFunction, 
AsyncTableFunction, LookupFunction, TableFunction, UserDefinedFunction}
 import org.apache.flink.table.planner.calcite.FlinkTypeFactory
 import org.apache.flink.table.planner.codegen.CodeGenUtils._
-import org.apache.flink.table.planner.codegen.GenerateUtils._
+import 
org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType
 import org.apache.flink.table.planner.codegen.Indenter.toISC
 import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil
 import 
org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil.verifyFunctionAwareImplementation
-import org.apache.flink.table.planner.delegation.PlannerBase
-import org.apache.flink.table.planner.functions.inference.LookupCallContext
-import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.{Constant, 
FieldRef, FunctionParam}
-import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
+import org.apache.flink.table.planner.functions.inference.FunctionCallContext
+import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam
 import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
 import org.apache.flink.table.runtime.collector.{ListenableCollector, 
TableFunctionResultFuture}
 import 
org.apache.flink.table.runtime.collector.ListenableCollector.CollectListener
@@ -57,10 +55,6 @@ import scala.collection.JavaConverters._
 
 object LookupJoinCodeGenerator {
 
-  case class GeneratedTableFunctionWithDataType[F <: Function](
-      tableFunc: GeneratedFunction[F],
-      dataType: DataType)
-
   private val ARRAY_LIST = className[util.ArrayList[_]]
 
   /** Generates a lookup function ([[TableFunction]]) */
@@ -76,29 +70,26 @@ object LookupJoinCodeGenerator {
       functionName: String,
       fieldCopy: Boolean): GeneratedFunction[FlatMapFunction[RowData, 
RowData]] = {
 
-    val bodyCode: GeneratedExpression => String = call => {
-      val resultCollectorTerm = call.resultTerm
-      s"""
-         |$resultCollectorTerm.setCollector($DEFAULT_COLLECTOR_TERM);
-         |${call.code}
-         |""".stripMargin
-    }
-
-    generateLookupFunction(
-      classOf[FlatMapFunction[RowData, RowData]],
-      tableConfig,
-      classLoader,
-      dataTypeFactory,
-      inputType,
-      tableSourceType,
-      returnType,
-      lookupKeys,
-      classOf[TableFunction[_]],
-      syncLookupFunction,
-      functionName,
-      fieldCopy,
-      bodyCode
-    ).tableFunc
+    FunctionCallCodeGenerator
+      .generateSyncFunctionCall(
+        tableConfig,
+        classLoader,
+        dataTypeFactory,
+        inputType,
+        tableSourceType,
+        returnType,
+        lookupKeys,
+        syncLookupFunction,
+        generateCallWithDataType(
+          dataTypeFactory,
+          functionName,
+          tableSourceType,
+          classOf[TableFunction[_]]),
+        functionName,
+        "LookupFunction",
+        fieldCopy
+      )
+      .tableFunc
   }
 
   /** Generates a async lookup function ([[AsyncTableFunction]]) */
@@ -112,9 +103,7 @@ object LookupJoinCodeGenerator {
       lookupKeys: util.List[FunctionParam],
       asyncLookupFunction: AsyncTableFunction[_],
       functionName: String): 
GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
-
-    generateLookupFunction(
-      classOf[AsyncFunction[RowData, AnyRef]],
+    FunctionCallCodeGenerator.generateAsyncFunctionCall(
       tableConfig,
       classLoader,
       dataTypeFactory,
@@ -122,98 +111,66 @@ object LookupJoinCodeGenerator {
       tableSourceType,
       returnType,
       lookupKeys,
-      classOf[AsyncTableFunction[_]],
       asyncLookupFunction,
+      generateCallWithDataType(
+        dataTypeFactory,
+        functionName,
+        tableSourceType,
+        classOf[AsyncTableFunction[_]]),
       functionName,
-      fieldCopy = true, // always copy input field because of async buffer
-      _.code
+      "AsyncLookupFunction"
     )
   }
 
-  private def generateLookupFunction[F <: Function](
-      generatedClass: Class[F],
-      tableConfig: ReadableConfig,
-      classLoader: ClassLoader,
+  private def generateCallWithDataType(
       dataTypeFactory: DataTypeFactory,
-      inputType: LogicalType,
-      tableSourceType: LogicalType,
-      returnType: LogicalType,
-      lookupKeys: util.List[FunctionParam],
-      lookupFunctionBase: Class[_],
-      lookupFunction: UserDefinedFunction,
       functionName: String,
-      fieldCopy: Boolean,
-      bodyCode: GeneratedExpression => String): 
GeneratedTableFunctionWithDataType[F] = {
-
-    val callContext =
-      new LookupCallContext(dataTypeFactory, lookupFunction, inputType, 
lookupKeys, tableSourceType)
-
-    // create the final UDF for runtime
-    val udf = UserDefinedFunctionHelper.createSpecializedFunction(
-      functionName,
-      lookupFunction,
-      callContext,
-      classOf[PlannerBase].getClassLoader,
-      tableConfig,
-      // no need to support expression evaluation at this point
-      null)
-
-    val inference =
-      createLookupTypeInference(dataTypeFactory, callContext, 
lookupFunctionBase, udf, functionName)
-
-    val ctx = new CodeGeneratorContext(tableConfig, classLoader)
-    val operands = prepareOperands(ctx, inputType, lookupKeys, fieldCopy)
-
-    // TODO: filter all records when there are any nulls on the join key, 
because
-    //  "IS NOT DISTINCT FROM" is not supported yet.
-    // Note: AsyncPredictFunction or PredictFunction does not use Lookup 
Syntax.
-    val skipIfArgsNull = !lookupFunction.isInstanceOf[PredictFunction] && 
!lookupFunction
-      .isInstanceOf[AsyncPredictFunction]
-
-    val callWithDataType = 
BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType(
-      ctx,
-      operands,
-      tableSourceType,
-      inference,
-      callContext,
-      udf,
-      functionName,
-      skipIfArgsNull = skipIfArgsNull
-    )
-
-    val function = FunctionCodeGenerator.generateFunction(
-      ctx,
-      "LookupFunction",
-      generatedClass,
-      bodyCode(callWithDataType._1),
-      returnType,
-      inputType)
-
-    GeneratedTableFunctionWithDataType(function, callWithDataType._2)
-  }
-
-  private def prepareOperands(
+      tableSourceType: LogicalType,
+      baseClass: Class[_]
+  ) = (
       ctx: CodeGeneratorContext,
-      inputType: LogicalType,
-      lookupKeys: util.List[FunctionParam],
-      fieldCopy: Boolean): Seq[GeneratedExpression] = {
+      callContext: FunctionCallContext,
+      udf: UserDefinedFunction,
+      operands: Seq[GeneratedExpression]) => {
+    def inferCallWithDataType(
+        ctx: CodeGeneratorContext,
+        callContext: FunctionCallContext,
+        udf: UserDefinedFunction,
+        operands: Seq[GeneratedExpression],
+        legacy: Boolean,
+        e: Exception = null): (GeneratedExpression, DataType) = {
+      val inference = createLookupTypeInference(
+        dataTypeFactory,
+        callContext,
+        baseClass,
+        udf,
+        functionName,
+        legacy,
+        e)
+
+      // TODO: filter all records when there are any nulls on the join key, 
because
+      //  "IS NOT DISTINCT FROM" is not supported yet.
+      val callWithDataType = 
BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType(
+        ctx,
+        operands,
+        tableSourceType,
+        inference,
+        callContext,
+        udf,
+        functionName,
+        skipIfArgsNull = true
+      )
+      callWithDataType
+    }
 
-    lookupKeys.asScala
-      .map {
-        case constantKey: Constant =>
-          val res = RexLiteralUtil.toFlinkInternalValue(constantKey.literal)
-          generateLiteral(ctx, res.f0, res.f1)
-        case fieldKey: FieldRef =>
-          generateInputAccess(
-            ctx,
-            inputType,
-            DEFAULT_INPUT1_TERM,
-            fieldKey.index,
-            nullableInput = false,
-            fieldCopy)
-        case _ =>
-          throw new CodeGenException("Invalid lookup key.")
-      }
+    try {
+      // user provided type inference has precedence
+      // this ensures that all functions work in the same way
+      inferCallWithDataType(ctx, callContext, udf, operands, legacy = false)
+    } catch {
+      case e: Exception =>
+        inferCallWithDataType(ctx, callContext, udf, operands, legacy = true, 
e)
+    }
   }
 
   /**
@@ -225,66 +182,58 @@ object LookupJoinCodeGenerator {
    */
   private def createLookupTypeInference(
       dataTypeFactory: DataTypeFactory,
-      callContext: LookupCallContext,
+      callContext: FunctionCallContext,
       baseClass: Class[_],
       udf: UserDefinedFunction,
-      functionName: String): TypeInference = {
+      functionName: String,
+      legacy: Boolean,
+      e: Exception): TypeInference = {
 
-    try {
+    if (!legacy) {
       // user provided type inference has precedence
       // this ensures that all functions work in the same way
       udf.getTypeInference(dataTypeFactory)
-    } catch {
-      case e: Exception =>
-        // for convenience, we assume internal or default external data 
structures
-        // of expected logical types
-        val defaultArgDataTypes = callContext.getArgumentDataTypes.asScala
-        val defaultOutputDataType = callContext.getOutputDataType.get()
-
-        val outputClass =
-          if (
-            udf.isInstanceOf[LookupFunction] || 
udf.isInstanceOf[AsyncLookupFunction] || udf
-              .isInstanceOf[PredictFunction] || 
udf.isInstanceOf[AsyncPredictFunction]
-          ) {
-            Some(classOf[RowData])
-          } else {
-            toScala(extractSimpleGeneric(baseClass, udf.getClass, 0))
-          }
-        val (argDataTypes, outputDataType) = outputClass match {
-          case Some(c) if c == classOf[Row] =>
-            (defaultArgDataTypes, defaultOutputDataType)
-          case Some(c) if c == classOf[RowData] =>
-            val internalArgDataTypes = defaultArgDataTypes
-              .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS))
-            val internalOutputDataType =
-              transform(defaultOutputDataType, 
TypeTransformations.TO_INTERNAL_CLASS)
-            (internalArgDataTypes, internalOutputDataType)
-          case _ =>
-            throw new ValidationException(
-              s"Could not determine a type inference for lookup function 
'$functionName'. " +
-                s"Lookup functions support regular type inference. However, 
for convenience, the " +
-                s"output class can simply be a ${classOf[Row].getSimpleName} 
or " +
-                s"${classOf[RowData].getSimpleName} class in which case the 
input and output " +
-                s"types are derived from the table's schema with default 
conversion.",
-              e)
+    } else {
+      // for convenience, we assume internal or default external data 
structures
+      // of expected logical types
+      val defaultArgDataTypes = callContext.getArgumentDataTypes.asScala
+      val defaultOutputDataType = callContext.getOutputDataType.get()
+
+      val outputClass =
+        if (udf.isInstanceOf[LookupFunction] || 
udf.isInstanceOf[AsyncLookupFunction]) {
+          Some(classOf[RowData])
+        } else {
+          toScala(extractSimpleGeneric(baseClass, udf.getClass, 0))
         }
+      val (argDataTypes, outputDataType) = outputClass match {
+        case Some(c) if c == classOf[Row] =>
+          (defaultArgDataTypes, defaultOutputDataType)
+        case Some(c) if c == classOf[RowData] =>
+          val internalArgDataTypes = defaultArgDataTypes
+            .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS))
+          val internalOutputDataType =
+            transform(defaultOutputDataType, 
TypeTransformations.TO_INTERNAL_CLASS)
+          (internalArgDataTypes, internalOutputDataType)
+        case _ =>
+          throw new ValidationException(
+            s"Could not determine a type inference for lookup function 
'$functionName'. " +
+              s"Lookup functions support regular type inference. However, for 
convenience, the " +
+              s"output class can simply be a ${classOf[Row].getSimpleName} or 
" +
+              s"${classOf[RowData].getSimpleName} class in which case the 
input and output " +
+              s"types are derived from the table's schema with default 
conversion.",
+            e)
+      }
 
-        verifyFunctionAwareImplementation(argDataTypes, outputDataType, udf, 
functionName)
+      verifyFunctionAwareImplementation(argDataTypes, outputDataType, udf, 
functionName)
 
-        TypeInference
-          .newBuilder()
-          .typedArguments(argDataTypes.asJava)
-          .outputTypeStrategy(TypeStrategies.explicit(outputDataType))
-          .build()
+      TypeInference
+        .newBuilder()
+        .typedArguments(argDataTypes.asJava)
+        .outputTypeStrategy(TypeStrategies.explicit(outputDataType))
+        .build()
     }
   }
 
-  /**
-   * Generates collector for temporal join ([[Collector]])
-   *
-   * Differs from CommonCorrelate.generateCollector which has no real 
condition because of
-   * FLINK-7865, here we should deal with outer join type when real conditions 
filtered result.
-   */
   def generateCollector(
       ctx: CodeGeneratorContext,
       inputRowType: RowType,
@@ -293,122 +242,14 @@ object LookupJoinCodeGenerator {
       condition: Option[RexNode],
       pojoFieldMapping: Option[Array[Int]],
       retainHeader: Boolean = true): 
GeneratedCollector[ListenableCollector[RowData]] = {
-
-    val inputTerm = DEFAULT_INPUT1_TERM
-    val rightInputTerm = DEFAULT_INPUT2_TERM
-
-    val exprGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
-      .bindInput(rightRowType, inputTerm = rightInputTerm, inputFieldMapping = 
pojoFieldMapping)
-
-    val rightResultExpr =
-      exprGenerator.generateConverterResultExpression(rightRowType, 
classOf[GenericRowData])
-
-    val joinedRowTerm = CodeGenUtils.newName(ctx, "joinedRow")
-    ctx.addReusableOutputRecord(resultRowType, classOf[JoinedRowData], 
joinedRowTerm)
-
-    val header = if (retainHeader) {
-      s"$joinedRowTerm.setRowKind($inputTerm.getRowKind());"
-    } else {
-      ""
-    }
-
-    val body =
-      s"""
-         |${rightResultExpr.code}
-         |$joinedRowTerm.replace($inputTerm, ${rightResultExpr.resultTerm});
-         |$header
-         |outputResult($joinedRowTerm);
-      """.stripMargin
-
-    val collectorCode = if (condition.isEmpty) {
-      body
-    } else {
-
-      val filterGenerator = new ExprCodeGenerator(ctx, nullableInput = false)
-        .bindInput(inputRowType, inputTerm)
-        .bindSecondInput(rightRowType, rightInputTerm, pojoFieldMapping)
-      val filterCondition = filterGenerator.generateExpression(condition.get)
-
-      s"""
-         |${filterCondition.code}
-         |if (${filterCondition.resultTerm}) {
-         |  $body
-         |}
-         |""".stripMargin
-    }
-
-    generateTableFunctionCollectorForJoinTable(
+    FunctionCallCodeGenerator.generateCollector(
       ctx,
-      "JoinTableFuncCollector",
-      collectorCode,
       inputRowType,
       rightRowType,
-      inputTerm = inputTerm,
-      collectedTerm = rightInputTerm)
-  }
-
-  /**
-   * The only differences against 
CollectorCodeGenerator.generateTableFunctionCollector is
-   * "super.collect" call is binding with collect join row in "body" code
-   */
-  private def generateTableFunctionCollectorForJoinTable(
-      ctx: CodeGeneratorContext,
-      name: String,
-      bodyCode: String,
-      inputType: RowType,
-      collectedType: RowType,
-      inputTerm: String = DEFAULT_INPUT1_TERM,
-      collectedTerm: String = DEFAULT_INPUT2_TERM)
-      : GeneratedCollector[ListenableCollector[RowData]] = {
-
-    val funcName = newName(ctx, name)
-    val input1TypeClass = boxedTypeTermForType(inputType)
-    val input2TypeClass = boxedTypeTermForType(collectedType)
-
-    val funcCode =
-      s"""
-      public class $funcName extends 
${classOf[ListenableCollector[_]].getCanonicalName} {
-
-        ${ctx.reuseMemberCode()}
-
-        public $funcName(Object[] references) throws Exception {
-          ${ctx.reuseInitCode()}
-        }
-
-        @Override
-        public void open(${className[OpenContext]} openContext) throws 
Exception {
-          ${ctx.reuseOpenCode()}
-        }
-
-        @Override
-        public void collect(Object record) throws Exception {
-          $input1TypeClass $inputTerm = ($input1TypeClass) getInput();
-          $input2TypeClass $collectedTerm = ($input2TypeClass) record;
-
-          // callback only when collectListener exists, equivalent to:
-          // getCollectListener().ifPresent(
-          //   listener -> ((CollectListener) listener).onCollect(record));
-          // TODO we should update code splitter's grammar file to accept 
lambda expressions.
-
-          if (getCollectListener().isPresent()) {
-             ((${classOf[CollectListener[_]].getCanonicalName}) 
getCollectListener().get())
-             .onCollect(record);
-          }
-
-          ${ctx.reuseLocalVariableCode()}
-          ${ctx.reuseInputUnboxingCode()}
-          ${ctx.reusePerRecordCode()}
-          $bodyCode
-        }
-
-        @Override
-        public void close() throws Exception {
-          ${ctx.reuseCloseCode()}
-        }
-      }
-    """.stripMargin
-
-    new GeneratedCollector(funcName, funcCode, ctx.references.toArray, 
ctx.tableConfig)
+      resultRowType,
+      condition,
+      pojoFieldMapping,
+      retainHeader)
   }
 
   /**
diff --git 
a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala
 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala
new file mode 100644
index 00000000000..68d322f8275
--- /dev/null
+++ 
b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/MLPredictCodeGenerator.scala
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.flink.table.planner.codegen
+
+import org.apache.flink.api.common.functions.FlatMapFunction
+import org.apache.flink.configuration.ReadableConfig
+import org.apache.flink.streaming.api.functions.async.AsyncFunction
+import org.apache.flink.table.catalog.DataTypeFactory
+import org.apache.flink.table.data.RowData
+import org.apache.flink.table.functions.{AsyncTableFunction, TableFunction, 
UserDefinedFunction}
+import 
org.apache.flink.table.planner.codegen.FunctionCallCodeGenerator.GeneratedTableFunctionWithDataType
+import org.apache.flink.table.planner.codegen.calls.BridgingFunctionGenUtil
+import org.apache.flink.table.planner.functions.inference.FunctionCallContext
+import org.apache.flink.table.planner.plan.utils.FunctionCallUtil.FunctionParam
+import org.apache.flink.table.runtime.collector.ListenableCollector
+import org.apache.flink.table.runtime.generated.{GeneratedCollector, 
GeneratedFunction}
+import org.apache.flink.table.types.inference.{TypeInference, TypeStrategies, 
TypeTransformations}
+import org.apache.flink.table.types.logical.{LogicalType, RowType}
+import org.apache.flink.table.types.utils.DataTypeUtils.transform
+
+import java.util
+
+import scala.collection.JavaConverters._
+
+object MLPredictCodeGenerator {
+
+  /** Generates a predict function ([[TableFunction]]) */
+  def generateSyncPredictFunction(
+      tableConfig: ReadableConfig,
+      classLoader: ClassLoader,
+      dataTypeFactory: DataTypeFactory,
+      inputType: LogicalType,
+      predictFunctionOutputType: LogicalType,
+      collectorOutputType: LogicalType,
+      features: util.List[FunctionParam],
+      syncPredictFunction: TableFunction[_],
+      functionName: String,
+      fieldCopy: Boolean
+  ): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
+    FunctionCallCodeGenerator
+      .generateSyncFunctionCall(
+        tableConfig,
+        classLoader,
+        dataTypeFactory,
+        inputType,
+        predictFunctionOutputType,
+        collectorOutputType,
+        features,
+        syncPredictFunction,
+        generateCallWithDataType(functionName, predictFunctionOutputType),
+        functionName,
+        "PredictFunction",
+        fieldCopy
+      )
+      .tableFunc
+  }
+
+  /** Generates an async predict function ([[AsyncTableFunction]]) */
+  def generateAsyncPredictFunction(
+      tableConfig: ReadableConfig,
+      classLoader: ClassLoader,
+      dataTypeFactory: DataTypeFactory,
+      inputType: LogicalType,
+      predictFunctionOutputType: LogicalType,
+      collectorOutputType: LogicalType,
+      features: util.List[FunctionParam],
+      asyncPredictFunction: AsyncTableFunction[_],
+      functionName: String): 
GeneratedTableFunctionWithDataType[AsyncFunction[RowData, AnyRef]] = {
+    FunctionCallCodeGenerator.generateAsyncFunctionCall(
+      tableConfig,
+      classLoader,
+      dataTypeFactory,
+      inputType,
+      predictFunctionOutputType,
+      collectorOutputType,
+      features,
+      asyncPredictFunction,
+      generateCallWithDataType(functionName, predictFunctionOutputType),
+      functionName,
+      "AsyncPredictFunction"
+    )
+  }
+
+  /** Generates a collector that joins the input row with the predicted 
results. */
+  def generateCollector(
+      ctx: CodeGeneratorContext,
+      inputRowType: RowType,
+      predictFunctionOutputType: RowType,
+      collectorOutputType: RowType
+  ): GeneratedCollector[ListenableCollector[RowData]] = {
+    FunctionCallCodeGenerator.generateCollector(
+      ctx,
+      inputRowType,
+      predictFunctionOutputType,
+      collectorOutputType,
+      Option.empty,
+      Option.empty
+    )
+  }
+
+  private def generateCallWithDataType(
+      functionName: String,
+      modelOutputType: LogicalType
+  ) = (
+      ctx: CodeGeneratorContext,
+      callContext: FunctionCallContext,
+      udf: UserDefinedFunction,
+      operands: Seq[GeneratedExpression]) => {
+    val inference = TypeInference
+      .newBuilder()
+      .typedArguments(
+        callContext.getArgumentDataTypes.asScala
+          .map(dt => transform(dt, TypeTransformations.TO_INTERNAL_CLASS))
+          .asJava)
+      .outputTypeStrategy(TypeStrategies.explicit(
+        transform(callContext.getOutputDataType.get(), 
TypeTransformations.TO_INTERNAL_CLASS)))
+      .build()
+    BridgingFunctionGenUtil.generateFunctionAwareCallWithDataType(
+      ctx,
+      operands,
+      modelOutputType,
+      inference,
+      callContext,
+      udf,
+      functionName,
+      skipIfArgsNull = false
+    )
+  }
+}
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java
similarity index 62%
copy from 
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
copy to 
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java
index 456f07bcd96..10e4defcfb4 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractAsyncFunctionRunner.java
@@ -16,30 +16,29 @@
  * limitations under the License.
  */
 
-package org.apache.flink.table.runtime.operators.calc.async;
+package org.apache.flink.table.runtime.operators;
 
 import org.apache.flink.api.common.functions.OpenContext;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
 import org.apache.flink.streaming.api.functions.async.AsyncFunction;
-import org.apache.flink.streaming.api.functions.async.ResultFuture;
 import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.AsyncLookupFunction;
+import org.apache.flink.table.functions.AsyncPredictFunction;
 import org.apache.flink.table.runtime.generated.GeneratedFunction;
 
 /**
- * Async function runner for {@link 
org.apache.flink.table.functions.AsyncScalarFunction}, which
- * takes the generated function, instantiates it, and then calls its lifecycle 
methods.
+ * Base function runner for specialized table functions, e.g. {@link 
AsyncLookupFunction} or {@link
+ * AsyncPredictFunction}.
  */
-public class AsyncFunctionRunner extends RichAsyncFunction<RowData, RowData> {
+public abstract class AbstractAsyncFunctionRunner<T> extends 
RichAsyncFunction<RowData, RowData> {
 
-    private static final long serialVersionUID = -7198305381139008806L;
+    protected final GeneratedFunction<AsyncFunction<RowData, T>> 
generatedFetcher;
 
-    private final GeneratedFunction<AsyncFunction<RowData, RowData>> 
generatedFetcher;
+    protected transient AsyncFunction<RowData, T> fetcher;
 
-    private transient AsyncFunction<RowData, RowData> fetcher;
-
-    public AsyncFunctionRunner(
-            GeneratedFunction<AsyncFunction<RowData, RowData>> 
generatedFetcher) {
+    public AbstractAsyncFunctionRunner(
+            GeneratedFunction<AsyncFunction<RowData, T>> generatedFetcher) {
         this.generatedFetcher = generatedFetcher;
     }
 
@@ -51,18 +50,11 @@ public class AsyncFunctionRunner extends 
RichAsyncFunction<RowData, RowData> {
         FunctionUtils.openFunction(fetcher, openContext);
     }
 
-    @Override
-    public void asyncInvoke(RowData input, ResultFuture<RowData> resultFuture) 
{
-        try {
-            fetcher.asyncInvoke(input, resultFuture);
-        } catch (Throwable t) {
-            resultFuture.completeExceptionally(t);
-        }
-    }
-
     @Override
     public void close() throws Exception {
         super.close();
-        FunctionUtils.closeFunction(fetcher);
+        if (fetcher != null) {
+            FunctionUtils.closeFunction(fetcher);
+        }
     }
 }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java
similarity index 53%
copy from 
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
copy to 
flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java
index 456f07bcd96..126aeba2204 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/AbstractFunctionRunner.java
@@ -16,53 +16,45 @@
  * limitations under the License.
  */
 
-package org.apache.flink.table.runtime.operators.calc.async;
+package org.apache.flink.table.runtime.operators;
 
+import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.OpenContext;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
-import org.apache.flink.streaming.api.functions.async.AsyncFunction;
-import org.apache.flink.streaming.api.functions.async.ResultFuture;
-import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
 import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.LookupFunction;
 import org.apache.flink.table.runtime.generated.GeneratedFunction;
 
 /**
- * Async function runner for {@link 
org.apache.flink.table.functions.AsyncScalarFunction}, which
- * takes the generated function, instantiates it, and then calls its lifecycle 
methods.
+ * Base function runner for specialized table functions, e.g. {@link 
LookupFunction} or {@link
+ * org.apache.flink.table.functions.PredictFunction}.
  */
-public class AsyncFunctionRunner extends RichAsyncFunction<RowData, RowData> {
+public abstract class AbstractFunctionRunner extends ProcessFunction<RowData, 
RowData> {
 
-    private static final long serialVersionUID = -7198305381139008806L;
+    private final GeneratedFunction<FlatMapFunction<RowData, RowData>> 
generatedFetcher;
 
-    private final GeneratedFunction<AsyncFunction<RowData, RowData>> 
generatedFetcher;
+    protected transient FlatMapFunction<RowData, RowData> fetcher;
 
-    private transient AsyncFunction<RowData, RowData> fetcher;
-
-    public AsyncFunctionRunner(
-            GeneratedFunction<AsyncFunction<RowData, RowData>> 
generatedFetcher) {
+    public AbstractFunctionRunner(
+            GeneratedFunction<FlatMapFunction<RowData, RowData>> 
generatedFetcher) {
         this.generatedFetcher = generatedFetcher;
     }
 
     @Override
     public void open(OpenContext openContext) throws Exception {
         super.open(openContext);
-        fetcher = 
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
+        this.fetcher = 
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
+
         FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
         FunctionUtils.openFunction(fetcher, openContext);
     }
 
-    @Override
-    public void asyncInvoke(RowData input, ResultFuture<RowData> resultFuture) 
{
-        try {
-            fetcher.asyncInvoke(input, resultFuture);
-        } catch (Throwable t) {
-            resultFuture.completeExceptionally(t);
-        }
-    }
-
     @Override
     public void close() throws Exception {
+        if (fetcher != null) {
+            FunctionUtils.closeFunction(fetcher);
+        }
         super.close();
-        FunctionUtils.closeFunction(fetcher);
     }
 }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
index 456f07bcd96..7cc3a7273f7 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/calc/async/AsyncFunctionRunner.java
@@ -18,37 +18,23 @@
 
 package org.apache.flink.table.runtime.operators.calc.async;
 
-import org.apache.flink.api.common.functions.OpenContext;
-import org.apache.flink.api.common.functions.util.FunctionUtils;
 import org.apache.flink.streaming.api.functions.async.AsyncFunction;
 import org.apache.flink.streaming.api.functions.async.ResultFuture;
-import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner;
 
 /**
  * Async function runner for {@link 
org.apache.flink.table.functions.AsyncScalarFunction}, which
  * takes the generated function, instantiates it, and then calls its lifecycle 
methods.
  */
-public class AsyncFunctionRunner extends RichAsyncFunction<RowData, RowData> {
+public class AsyncFunctionRunner extends AbstractAsyncFunctionRunner<RowData> {
 
     private static final long serialVersionUID = -7198305381139008806L;
 
-    private final GeneratedFunction<AsyncFunction<RowData, RowData>> 
generatedFetcher;
-
-    private transient AsyncFunction<RowData, RowData> fetcher;
-
     public AsyncFunctionRunner(
             GeneratedFunction<AsyncFunction<RowData, RowData>> 
generatedFetcher) {
-        this.generatedFetcher = generatedFetcher;
-    }
-
-    @Override
-    public void open(OpenContext openContext) throws Exception {
-        super.open(openContext);
-        fetcher = 
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
-        FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
-        FunctionUtils.openFunction(fetcher, openContext);
+        super(generatedFetcher);
     }
 
     @Override
@@ -59,10 +45,4 @@ public class AsyncFunctionRunner extends 
RichAsyncFunction<RowData, RowData> {
             resultFuture.completeExceptionally(t);
         }
     }
-
-    @Override
-    public void close() throws Exception {
-        super.close();
-        FunctionUtils.closeFunction(fetcher);
-    }
 }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
index 310f38d5489..6b34c99f4a7 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/AsyncLookupJoinRunner.java
@@ -26,7 +26,6 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.streaming.api.functions.async.AsyncFunction;
 import org.apache.flink.streaming.api.functions.async.CollectionSupplier;
 import org.apache.flink.streaming.api.functions.async.ResultFuture;
-import org.apache.flink.streaming.api.functions.async.RichAsyncFunction;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.conversion.DataStructureConverter;
@@ -35,6 +34,7 @@ import 
org.apache.flink.table.runtime.collector.TableFunctionResultFuture;
 import org.apache.flink.table.runtime.generated.FilterCondition;
 import org.apache.flink.table.runtime.generated.GeneratedFunction;
 import org.apache.flink.table.runtime.generated.GeneratedResultFuture;
+import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner;
 import org.apache.flink.table.runtime.typeutils.RowDataSerializer;
 
 import java.util.ArrayList;
@@ -45,10 +45,9 @@ import java.util.concurrent.ArrayBlockingQueue;
 import java.util.concurrent.BlockingQueue;
 
 /** The async join runner to lookup the dimension table. */
-public class AsyncLookupJoinRunner extends RichAsyncFunction<RowData, RowData> 
{
+public class AsyncLookupJoinRunner extends AbstractAsyncFunctionRunner<Object> 
{
     private static final long serialVersionUID = -6664660022391632480L;
 
-    private final GeneratedFunction<AsyncFunction<RowData, Object>> 
generatedFetcher;
     private final DataStructureConverter<RowData, Object> fetcherConverter;
     private final GeneratedResultFuture<TableFunctionResultFuture<RowData>> 
generatedResultFuture;
     private final GeneratedFunction<FilterCondition> 
generatedPreFilterCondition;
@@ -56,8 +55,6 @@ public class AsyncLookupJoinRunner extends 
RichAsyncFunction<RowData, RowData> {
     private final boolean isLeftOuterJoin;
     private final int asyncBufferCapacity;
 
-    private transient AsyncFunction<RowData, Object> fetcher;
-
     protected final RowDataSerializer rightRowSerializer;
 
     /**
@@ -83,7 +80,7 @@ public class AsyncLookupJoinRunner extends 
RichAsyncFunction<RowData, RowData> {
             RowDataSerializer rightRowSerializer,
             boolean isLeftOuterJoin,
             int asyncBufferCapacity) {
-        this.generatedFetcher = generatedFetcher;
+        super(generatedFetcher);
         this.fetcherConverter = fetcherConverter;
         this.generatedResultFuture = generatedResultFuture;
         this.generatedPreFilterCondition = generatedPreFilterCondition;
@@ -96,11 +93,9 @@ public class AsyncLookupJoinRunner extends 
RichAsyncFunction<RowData, RowData> {
     public void open(OpenContext openContext) throws Exception {
         super.open(openContext);
         ClassLoader cl = getRuntimeContext().getUserCodeClassLoader();
-        this.fetcher = generatedFetcher.newInstance(cl);
         this.preFilterCondition = generatedPreFilterCondition.newInstance(cl);
         FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
         FunctionUtils.setFunctionRuntimeContext(preFilterCondition, 
getRuntimeContext());
-        FunctionUtils.openFunction(fetcher, openContext);
         FunctionUtils.openFunction(preFilterCondition, openContext);
 
         // try to compile the generated ResultFuture, fail fast if the code is 
corrupt.
@@ -152,9 +147,6 @@ public class AsyncLookupJoinRunner extends 
RichAsyncFunction<RowData, RowData> {
     @Override
     public void close() throws Exception {
         super.close();
-        if (fetcher != null) {
-            FunctionUtils.closeFunction(fetcher);
-        }
         if (preFilterCondition != null) {
             FunctionUtils.closeFunction(preFilterCondition);
         }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
index c91a9feb7bf..4ecb0b53f59 100644
--- 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/join/lookup/LookupJoinRunner.java
@@ -21,7 +21,6 @@ package org.apache.flink.table.runtime.operators.join.lookup;
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.OpenContext;
 import org.apache.flink.api.common.functions.util.FunctionUtils;
-import org.apache.flink.streaming.api.functions.ProcessFunction;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.RowData;
 import org.apache.flink.table.data.utils.JoinedRowData;
@@ -29,20 +28,19 @@ import 
org.apache.flink.table.runtime.collector.ListenableCollector;
 import org.apache.flink.table.runtime.generated.FilterCondition;
 import org.apache.flink.table.runtime.generated.GeneratedCollector;
 import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractFunctionRunner;
 import org.apache.flink.util.Collector;
 
 /** The join runner to lookup the dimension table. */
-public class LookupJoinRunner extends ProcessFunction<RowData, RowData> {
+public class LookupJoinRunner extends AbstractFunctionRunner {
     private static final long serialVersionUID = -4521543015709964733L;
 
-    private final GeneratedFunction<FlatMapFunction<RowData, RowData>> 
generatedFetcher;
     private final GeneratedCollector<ListenableCollector<RowData>> 
generatedCollector;
     private final GeneratedFunction<FilterCondition> 
generatedPreFilterCondition;
 
     protected final boolean isLeftOuterJoin;
     protected final int tableFieldsCount;
 
-    private transient FlatMapFunction<RowData, RowData> fetcher;
     protected transient ListenableCollector<RowData> collector;
     protected transient JoinedRowData outRow;
     protected transient FilterCondition preFilterCondition;
@@ -54,7 +52,7 @@ public class LookupJoinRunner extends 
ProcessFunction<RowData, RowData> {
             GeneratedFunction<FilterCondition> generatedPreFilterCondition,
             boolean isLeftOuterJoin,
             int tableFieldsCount) {
-        this.generatedFetcher = generatedFetcher;
+        super(generatedFetcher);
         this.generatedCollector = generatedCollector;
         this.generatedPreFilterCondition = generatedPreFilterCondition;
         this.isLeftOuterJoin = isLeftOuterJoin;
@@ -64,17 +62,14 @@ public class LookupJoinRunner extends 
ProcessFunction<RowData, RowData> {
     @Override
     public void open(OpenContext openContext) throws Exception {
         super.open(openContext);
-        this.fetcher = 
generatedFetcher.newInstance(getRuntimeContext().getUserCodeClassLoader());
         this.collector =
                 
generatedCollector.newInstance(getRuntimeContext().getUserCodeClassLoader());
         this.preFilterCondition =
                 generatedPreFilterCondition.newInstance(
                         getRuntimeContext().getUserCodeClassLoader());
 
-        FunctionUtils.setFunctionRuntimeContext(fetcher, getRuntimeContext());
         FunctionUtils.setFunctionRuntimeContext(collector, 
getRuntimeContext());
         FunctionUtils.setFunctionRuntimeContext(preFilterCondition, 
getRuntimeContext());
-        FunctionUtils.openFunction(fetcher, openContext);
         FunctionUtils.openFunction(collector, openContext);
         FunctionUtils.openFunction(preFilterCondition, openContext);
 
@@ -124,9 +119,6 @@ public class LookupJoinRunner extends 
ProcessFunction<RowData, RowData> {
 
     @Override
     public void close() throws Exception {
-        if (fetcher != null) {
-            FunctionUtils.closeFunction(fetcher);
-        }
         if (collector != null) {
             FunctionUtils.closeFunction(collector);
         }
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java
new file mode 100644
index 00000000000..aa77eedaeaf
--- /dev/null
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/AsyncMLPredictRunner.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.ml;
+
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.metrics.MetricGroup;
+import org.apache.flink.streaming.api.functions.async.AsyncFunction;
+import org.apache.flink.streaming.api.functions.async.CollectionSupplier;
+import org.apache.flink.streaming.api.functions.async.ResultFuture;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.data.utils.JoinedRowData;
+import org.apache.flink.table.functions.AsyncPredictFunction;
+import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractAsyncFunctionRunner;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.concurrent.ArrayBlockingQueue;
+import java.util.concurrent.BlockingQueue;
+
+/**
+ * Async function runner for {@link AsyncPredictFunction}, which takes the generated function,
+ * instantiates it, and then calls its lifecycle methods.
+ *
+ * <p>Every row produced by the fetcher for an input is combined with that input row into a {@link
+ * JoinedRowData} before being forwarded to the downstream {@link ResultFuture}. To avoid creating
+ * a new {@link ResultFuture} per element, the runner keeps a fixed pool of {@code
+ * asyncBufferCapacity + 1} reusable futures. An entry goes back to the pool once its request has
+ * completed, either normally or exceptionally.
+ */
+public class AsyncMLPredictRunner extends AbstractAsyncFunctionRunner<RowData> {
+
+    private static final long serialVersionUID = 1L;
+
+    private final int asyncBufferCapacity;
+
+    /**
+     * Pool of reusable {@link JoinedRowResultFuture}s, which avoids the cost of creating a new
+     * instance for every element. A {@link BlockingQueue} is used so that {@link
+     * #asyncInvoke(RowData, ResultFuture)} waits for an entry to be released when all of them
+     * are in use.
+     */
+    private transient BlockingQueue<JoinedRowResultFuture> resultFutureBuffer;
+
+    public AsyncMLPredictRunner(
+            GeneratedFunction<AsyncFunction<RowData, RowData>> generatedFetcher,
+            int asyncBufferCapacity) {
+        super(generatedFetcher);
+        this.asyncBufferCapacity = asyncBufferCapacity;
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+        this.resultFutureBuffer = new ArrayBlockingQueue<>(asyncBufferCapacity + 1);
+        for (int i = 0; i < asyncBufferCapacity + 1; i++) {
+            JoinedRowResultFuture rf = new JoinedRowResultFuture(resultFutureBuffer);
+            // add will throw exception immediately if the queue is full which should never happen
+            resultFutureBuffer.add(rf);
+        }
+        registerMetric(getRuntimeContext().getMetricGroup());
+    }
+
+    /**
+     * Borrows a pooled future, binds it to the input row and the real output, and hands it to
+     * the generated fetcher. Any failure is reported to {@code resultFuture}.
+     *
+     * @param input the input row, joined with every result produced for it
+     * @param resultFuture the downstream future receiving the joined rows or the failure
+     */
+    @Override
+    public void asyncInvoke(RowData input, ResultFuture<RowData> resultFuture) throws Exception {
+        JoinedRowResultFuture buffer = null;
+        try {
+            buffer = resultFutureBuffer.take();
+            buffer.reset(input, resultFuture);
+            fetcher.asyncInvoke(input, buffer);
+        } catch (Throwable t) {
+            if (t instanceof InterruptedException) {
+                // keep the interrupt visible to the caller instead of swallowing it
+                Thread.currentThread().interrupt();
+            }
+            if (buffer != null) {
+                // Fail through the pooled future so it is also handed back to the pool;
+                // otherwise the entry is lost for good and take() may eventually block forever.
+                buffer.completeExceptionally(t);
+            } else {
+                resultFuture.completeExceptionally(t);
+            }
+        }
+    }
+
+    private void registerMetric(MetricGroup metricGroup) {
+        metricGroup.gauge(
+                "ai_queue_length", () -> asyncBufferCapacity + 1 - resultFutureBuffer.size());
+        metricGroup.gauge("ai_queue_capacity", () -> asyncBufferCapacity);
+        metricGroup.gauge(
+                "ai_queue_usage_ratio",
+                () ->
+                        1.0
+                                * (asyncBufferCapacity + 1 - resultFutureBuffer.size())
+                                / asyncBufferCapacity);
+    }
+
+    /**
+     * Reusable {@link ResultFuture} that joins each result row with the current input row,
+     * forwards the result to the real output, and then returns itself to the pool.
+     */
+    private static final class JoinedRowResultFuture implements ResultFuture<RowData> {
+
+        private final BlockingQueue<JoinedRowResultFuture> resultFutureBuffer;
+
+        private ResultFuture<RowData> realOutput;
+        private RowData leftRow;
+
+        public JoinedRowResultFuture(BlockingQueue<JoinedRowResultFuture> resultFutureBuffer) {
+            this.resultFutureBuffer = resultFutureBuffer;
+        }
+
+        /** Binds this pooled future to a new input row and its downstream output. */
+        public void reset(RowData row, ResultFuture<RowData> realOutput) {
+            this.realOutput = realOutput;
+            this.leftRow = row;
+        }
+
+        @Override
+        public void complete(Collection<RowData> result) {
+            try {
+                // a null result is treated as "no predictions", like an empty collection
+                List<RowData> outRows = new ArrayList<>(result == null ? 0 : result.size());
+                if (result != null) {
+                    for (RowData rightRow : result) {
+                        outRows.add(new JoinedRowData(leftRow.getRowKind(), leftRow, rightRow));
+                    }
+                }
+                realOutput.complete(outRows);
+            } finally {
+                // Only release after outRows has been handed over, so this future is not
+                // reused (and leftRow/realOutput overwritten) while it is still in use.
+                release();
+            }
+        }
+
+        @Override
+        public void completeExceptionally(Throwable error) {
+            try {
+                realOutput.completeExceptionally(error);
+            } finally {
+                release();
+            }
+        }
+
+        @Override
+        public void complete(CollectionSupplier<RowData> supplier) {
+            throw new UnsupportedOperationException();
+        }
+
+        /** Hands this future back to the pool so it can serve another element. */
+        private void release() {
+            // The pool holds exactly as many entries as the queue can store, and each entry is
+            // taken before it is released, so there is always room and offer() should never
+            // fail. Unlike put(), it cannot lose the entry when the completing thread has been
+            // interrupted.
+            resultFutureBuffer.offer(this);
+        }
+    }
+}
diff --git 
a/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java
 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java
new file mode 100644
index 00000000000..a8a486911f8
--- /dev/null
+++ 
b/flink-table/flink-table-runtime/src/main/java/org/apache/flink/table/runtime/operators/ml/MLPredictRunner.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.table.runtime.operators.ml;
+
+import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.common.functions.OpenContext;
+import org.apache.flink.api.common.functions.util.FunctionUtils;
+import org.apache.flink.streaming.api.functions.ProcessFunction;
+import org.apache.flink.table.data.RowData;
+import org.apache.flink.table.functions.PredictFunction;
+import org.apache.flink.table.runtime.collector.ListenableCollector;
+import org.apache.flink.table.runtime.generated.GeneratedCollector;
+import org.apache.flink.table.runtime.generated.GeneratedFunction;
+import org.apache.flink.table.runtime.operators.AbstractFunctionRunner;
+import org.apache.flink.util.Collector;
+
+/**
+ * Function runner for {@link PredictFunction}, which takes the generated function, instantiates it,
+ * and then calls its lifecycle methods.
+ *
+ * <p>For every input row, the generated {@link ListenableCollector} is first given the input row
+ * and the downstream {@link Collector}. The fetcher is then invoked with that collector as its
+ * output, so everything the fetcher emits passes through the generated collector.
+ */
+public class MLPredictRunner extends AbstractFunctionRunner {
+
+    private static final long serialVersionUID = 1L;
+
+    private final GeneratedCollector<ListenableCollector<RowData>> generatedCollector;
+
+    protected transient ListenableCollector<RowData> collector;
+
+    public MLPredictRunner(
+            GeneratedFunction<FlatMapFunction<RowData, RowData>> generatedFetcher,
+            GeneratedCollector<ListenableCollector<RowData>> generatedCollector) {
+        super(generatedFetcher);
+        this.generatedCollector = generatedCollector;
+    }
+
+    @Override
+    public void open(OpenContext openContext) throws Exception {
+        super.open(openContext);
+
+        this.collector =
+                generatedCollector.newInstance(getRuntimeContext().getUserCodeClassLoader());
+        FunctionUtils.setFunctionRuntimeContext(collector, getRuntimeContext());
+        FunctionUtils.openFunction(collector, openContext);
+    }
+
+    @Override
+    public void processElement(
+            RowData in, ProcessFunction<RowData, RowData>.Context ctx, Collector<RowData> out)
+            throws Exception {
+        prepareCollector(in, out);
+        fetcher.flatMap(in, collector);
+    }
+
+    /**
+     * Points the generated collector at the current input row and the downstream collector, and
+     * resets it before the fetcher is invoked.
+     *
+     * @param in the current input row
+     * @param out the downstream collector
+     */
+    public void prepareCollector(RowData in, Collector<RowData> out) {
+        collector.setCollector(out);
+        collector.setInput(in);
+        collector.reset();
+    }
+
+    /**
+     * Closes the generated collector opened in {@link #open(OpenContext)}, then lets the base
+     * class close the fetcher.
+     */
+    @Override
+    public void close() throws Exception {
+        try {
+            // The base class only closes the fetcher; the collector is opened here, so it must
+            // be closed here as well (mirrors LookupJoinRunner).
+            if (collector != null) {
+                FunctionUtils.closeFunction(collector);
+            }
+        } finally {
+            super.close();
+        }
+    }
+}

Reply via email to