This is an automated email from the ASF dual-hosted git repository.
snuyanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/master by this push:
new 5da2287dac0 [FLINK-39268][table] Expand and reuse local refs in `CalcCodeGenerator`
5da2287dac0 is described below
commit 5da2287dac001c21eb0cba4ee9bb073c1c410347
Author: Sergey Nuyanzin <[email protected]>
AuthorDate: Thu May 7 18:40:59 2026 +0200
[FLINK-39268][table] Expand and reuse local refs in `CalcCodeGenerator`
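The Calc code generator now folds the projection expressions and the optional
condition into a single RexProgram, so structurally identical sub-expressions
collapse into shared RexLocalRef entries of the program's exprList.
visitLocalRef generates the body of a deterministic entry once, caches it by
index, and later references reuse only the result term. A scope stack in
CodeGeneratorContext keeps bodies that belong to conditionally evaluated
operands (CASE/AND/OR/COALESCE branches) inside their guard instead of
hoisting them to the top of the method. For a query like
`SELECT (f0 + 1) * 2 FROM t WHERE f0 + 1 > 0`, the common `f0 + 1` is thus
generated once for the filter and reused by the projection.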
---
.../functions/sql/SqlJsonArrayFunctionWrapper.java | 2 +-
.../sql/SqlJsonObjectFunctionWrapper.java | 2 +-
.../functions/sql/SqlJsonQueryFunctionWrapper.java | 2 +-
.../functions/sql/SqlJsonValueFunctionWrapper.java | 2 +-
.../plan/nodes/exec/common/CommonExecCalc.java | 4 +-
.../nodes/exec/common/CommonExecLookupJoin.java | 6 +-
.../plan/nodes/exec/spec/DeltaJoinTree.java | 3 +-
.../nodes/exec/stream/StreamExecDeltaJoin.java | 3 +-
.../flink/table/planner/utils/ShortcutUtils.java | 51 +-
.../table/planner/codegen/CalcCodeGenerator.scala | 72 ++-
.../flink/table/planner/codegen/CodeGenUtils.scala | 10 +
.../planner/codegen/CodeGeneratorContext.scala | 89 +++-
.../table/planner/codegen/ExprCodeGenerator.scala | 133 ++++-
.../table/planner/codegen/ExpressionReducer.scala | 9 +-
.../planner/codegen/FunctionCodeGenerator.scala | 3 +-
.../table/planner/codegen/JsonGenerateUtils.scala | 104 ++--
.../planner/codegen/LongHashJoinGenerator.scala | 2 +-
.../planner/codegen/LookupJoinCodeGenerator.scala | 18 +-
.../codegen/calls/BridgingFunctionGenUtil.scala | 7 +-
.../codegen/calls/BridgingSqlFunctionCallGen.scala | 6 +-
.../planner/codegen/calls/JsonArrayCallGen.scala | 11 +-
.../planner/codegen/calls/JsonObjectCallGen.scala | 10 +-
.../planner/codegen/calls/JsonStringCallGen.scala | 10 +-
.../planner/codegen/calls/SearchOperatorGen.scala | 14 +-
.../table/planner/plan/utils/FlinkRexUtil.scala | 8 +
.../planner/functions/JsonFunctionsITCase.java | 202 +++++++-
.../planner/runtime/stream/sql/FunctionITCase.java | 561 ++++++++++++++++++---
27 files changed, 1132 insertions(+), 212 deletions(-)
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java
index e1b60699d7f..eb6aae3a8ed 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonArrayFunctionWrapper.java
@@ -29,7 +29,7 @@ import static org.apache.flink.table.planner.plan.type.FlinkReturnTypes.VARCHAR_
* This class is a wrapper class for the {@link SqlJsonArrayFunction} but using the {@code
* VARCHAR_NOT_NULL} return type inference.
*/
-class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction {
+public class SqlJsonArrayFunctionWrapper extends SqlJsonArrayFunction {
@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java
index b09ab149a64..b4ef34b94c5 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonObjectFunctionWrapper.java
@@ -29,7 +29,7 @@ import static org.apache.flink.table.planner.plan.type.FlinkReturnTypes.VARCHAR_
* This class is a wrapper class for the {@link SqlJsonObjectFunction} but using the {@code
* VARCHAR_NOT_NULL} return type inference.
*/
-class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction {
+public class SqlJsonObjectFunctionWrapper extends SqlJsonObjectFunction {
@Override
public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java
index 7a145ba9cce..ddae97fa232 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonQueryFunctionWrapper.java
@@ -42,7 +42,7 @@ import static org.apache.flink.table.planner.plan.type.FlinkReturnTypes.VARCHAR_
* This class is a wrapper class for the {@link SqlJsonQueryFunction} but using the {@code
* VARCHAR_FORCE_NULLABLE} return type inference.
*/
-class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction {
+public class SqlJsonQueryFunctionWrapper extends SqlJsonQueryFunction {
private final SqlReturnTypeInference returnTypeInference;
SqlJsonQueryFunctionWrapper() {
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java
index b28ef4786e4..03e60ecdf3c 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/SqlJsonValueFunctionWrapper.java
@@ -35,7 +35,7 @@ import static org.apache.flink.table.planner.plan.type.FlinkReturnTypes.VARCHAR_
* VARCHAR_FORCE_NULLABLE} return type inference by default. It also supports specifying return type
* with the RETURNING keyword just like the original {@link SqlJsonValueFunction}.
*/
-class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction {
+public class SqlJsonValueFunctionWrapper extends SqlJsonValueFunction {
private final SqlReturnTypeInference returnTypeInference;
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java
index cf389655031..e1ddbcbc46e 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecCalc.java
@@ -32,6 +32,7 @@ import org.apache.flink.table.planner.plan.nodes.exec.InputProperty;
import org.apache.flink.table.planner.plan.nodes.exec.SingleTransformationTranslator;
import org.apache.flink.table.planner.plan.nodes.exec.utils.ExecNodeUtil;
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil;
+import org.apache.flink.table.planner.utils.ShortcutUtils;
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory;
import org.apache.flink.table.runtime.typeutils.InternalTypeInfo;
import org.apache.flink.table.types.logical.RowType;
@@ -99,10 +100,11 @@ public abstract class CommonExecCalc extends ExecNodeBase<RowData>
final CodeGenOperatorFactory<RowData> substituteStreamOperator =
CalcCodeGenerator.generateCalcOperator(
ctx,
- inputTransform,
+ (RowType) inputEdge.getOutputType(),
(RowType) getOutputType(),
JavaScalaConversionUtil.toScala(projection),
JavaScalaConversionUtil.toScala(Optional.ofNullable(this.condition)),
+ ShortcutUtils.unwrapTypeFactory(planner),
retainHeader,
getClass().getSimpleName());
return ExecNodeUtil.createOneInputTransformation(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
index 05b07458a10..31cb09a545f 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/common/CommonExecLookupJoin.java
@@ -502,7 +502,8 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> {
JavaScalaConversionUtil.toScala(projectionOnTemporalTable),
filterOnTemporalTable,
projectionOutputRelDataType,
- tableSourceRowType);
+ tableSourceRowType,
+ ShortcutUtils.unwrapTypeFactory(relBuilder));
asyncFunc =
new AsyncLookupJoinWithCalcRunner(
generatedFuncWithType.tableFunc(),
@@ -647,7 +648,8 @@ public abstract class CommonExecLookupJoin extends ExecNodeBase<RowData> {
JavaScalaConversionUtil.toScala(projectionOnTemporalTable),
filterOnTemporalTable,
projectionOutputRelDataType,
- tableSourceRowType);
+ tableSourceRowType,
+ ShortcutUtils.unwrapTypeFactory(relBuilder));
processFunc =
new LookupJoinWithCalcRunner(
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java
index f7555e06822..c826fcd1ff1 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/spec/DeltaJoinTree.java
@@ -228,7 +228,8 @@ public class DeltaJoinTree {
node.filter,
rowTypePassThroughCalc,
rowTypeBeforeCalc,
- generatedCalcName))
+ generatedCalcName,
+ typeFactory))
.orElse(null);
if (node instanceof BinaryInputNode) {
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
index c540c543a6d..c1b6ffcd179 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/plan/nodes/exec/stream/StreamExecDeltaJoin.java
@@ -574,7 +574,8 @@ public class StreamExecDeltaJoin extends ExecNodeBase<RowData>
JavaScalaConversionUtil.toScala(projectionOnTemporalTable),
filterOnTemporalTable,
lookupSidePassThroughCalcRowType,
- lookupTableSourceRowType);
+ lookupTableSourceRowType,
+ typeFactory);
}
Preconditions.checkState(!generatedFetcherCollector.containsKey(lookupTableOrdinal));
diff --git a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java
index 415b78efeac..b079eddf557 100644
--- a/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java
+++ b/flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/utils/ShortcutUtils.java
@@ -40,13 +40,20 @@ import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexFieldAccess;
+import org.apache.calcite.rex.RexLocalRef;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.tools.RelBuilder;
import javax.annotation.Nullable;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
/**
* Utilities for quick access of commonly used instances (like {@link FlinkTypeFactory}) without
* long chains of getters or casting like {@code (FlinkTypeFactory)
@@ -147,14 +154,15 @@ public final class ShortcutUtils {
return null;
}
final RexCall call = (RexCall) rexNode;
- if (!(call.getOperator() instanceof BridgingSqlFunction)) {
+ final SqlOperator operator = call.getOperator();
+ if (!(operator instanceof BridgingSqlFunction)) {
// legacy
- if (call.getOperator() instanceof TableSqlFunction) {
- return ((TableSqlFunction) call.getOperator()).udtf();
+ if (operator instanceof TableSqlFunction) {
+ return ((TableSqlFunction) operator).udtf();
}
return null;
}
- return ((BridgingSqlFunction) call.getOperator()).getDefinition();
+ return ((BridgingSqlFunction) operator).getDefinition();
}
public static @Nullable FunctionDefinition unwrapFunctionDefinition(SqlOperator operator) {
@@ -169,6 +177,41 @@ public final class ShortcutUtils {
return functionDefinition != null && functionDefinition.getKind() == kind;
}
+ public static boolean isDeterministicThroughProgram(
+ RexNode node, @Nullable List<RexNode> exprs) {
+ if (exprs == null) {
+ return RexUtil.isDeterministic(node);
+ }
+ return isDeterministicThroughProgram(node, exprs, new HashSet<>());
+ }
+
+ private static boolean isDeterministicThroughProgram(
+ RexNode node, List<RexNode> exprs, Set<Integer> visited) {
+ if (node instanceof RexCall) {
+ final RexCall call = (RexCall) node;
+ if (!call.getOperator().isDeterministic()) {
+ return false;
+ }
+ for (RexNode operand : call.getOperands()) {
+ if (!isDeterministicThroughProgram(operand, exprs, visited)) {
+ return false;
+ }
+ }
+ return true;
+ }
+ if (node instanceof RexLocalRef) {
+ final int idx = ((RexLocalRef) node).getIndex();
+ // already on the stack: skip rather than recurse forever
+ return !visited.add(idx)
+ || isDeterministicThroughProgram(exprs.get(idx), exprs, visited);
+ }
+ if (node instanceof RexFieldAccess) {
+ return isDeterministicThroughProgram(
+ ((RexFieldAccess) node).getReferenceExpr(), exprs, visited);
+ }
+ return true;
+ }
+
public static @Nullable BridgingSqlFunction unwrapBridgingSqlFunction(RexCall call) {
final SqlOperator operator = call.getOperator();
if (operator instanceof BridgingSqlFunction) {
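A note on the visited set above: a RexLocalRef is resolved through the
program's exprList, and a ref that is already on the recursion stack is
skipped rather than followed forever. A self-contained toy mirror of that
logic (a simplified sketch, not the planner code; `Call` and `LocalRef` are
stand-ins for Calcite's RexCall and RexLocalRef):

    import scala.collection.mutable

    sealed trait Rex
    case class Call(deterministic: Boolean, operands: Rex*) extends Rex
    case class LocalRef(index: Int) extends Rex

    def isDeterministicThroughProgram(
        node: Rex,
        exprs: Vector[Rex],
        visited: mutable.Set[Int] = mutable.Set[Int]()): Boolean = node match {
      case Call(det, ops @ _*) =>
        // a call is deterministic iff its operator and all its operands are
        det && ops.forall(isDeterministicThroughProgram(_, exprs, visited))
      case LocalRef(i) =>
        // `!visited.add(i)` means the index was already seen: skip, don't recurse
        !visited.add(i) || isDeterministicThroughProgram(exprs(i), exprs, visited)
    }

    // $t0 = f0, $t1 = RAND() (non-deterministic), $t2 = $t0 + $t1
    val exprs = Vector(Call(true), Call(false), Call(true, LocalRef(0), LocalRef(1)))
    assert(!isDeterministicThroughProgram(LocalRef(2), exprs))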
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala
index 8072a9ca427..966aca0abec 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CalcCodeGenerator.scala
@@ -18,32 +18,31 @@
package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.functions.{FlatMapFunction, Function}
-import org.apache.flink.api.dag.Transformation
import org.apache.flink.configuration.ReadableConfig
import org.apache.flink.table.api.{TableException, ValidationException}
import org.apache.flink.table.data.{BoxedWrapperRowData, RowData}
import org.apache.flink.table.functions.FunctionKind
+import org.apache.flink.table.planner.calcite.{FlinkRexBuilder, FlinkTypeFactory}
import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
import org.apache.flink.table.runtime.generated.GeneratedFunction
import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory
-import org.apache.flink.table.runtime.typeutils.InternalTypeInfo
import org.apache.flink.table.types.logical.RowType
import org.apache.calcite.rex._
+import scala.collection.JavaConverters._
+
object CalcCodeGenerator {
def generateCalcOperator(
ctx: CodeGeneratorContext,
- inputTransform: Transformation[RowData],
+ inputType: RowType,
outputType: RowType,
projection: Seq[RexNode],
condition: Option[RexNode],
+ typeFactory: FlinkTypeFactory,
retainHeader: Boolean = false,
opName: String): CodeGenOperatorFactory[RowData] = {
- val inputType = inputTransform.getOutputType
-   .asInstanceOf[InternalTypeInfo[RowData]]
-   .toRowType
// filter out time attributes
val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
val processCode = generateProcessCode(
@@ -53,8 +52,13 @@ object CalcCodeGenerator {
classOf[BoxedWrapperRowData],
projection,
condition,
+ typeFactory,
+ inputTerm,
+ CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM,
eagerInputUnboxingCode = true,
- retainHeader = retainHeader)
+ retainHeader = retainHeader,
+ outputDirectly = false
+ )
val genOperator =
OperatorCodeGenerator.generateOneInputStreamOperator[RowData, RowData](
@@ -76,7 +80,8 @@ object CalcCodeGenerator {
calcProjection: Seq[RexNode],
calcCondition: Option[RexNode],
tableConfig: ReadableConfig,
- classLoader: ClassLoader): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
+ classLoader: ClassLoader,
+ typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
val ctx = new CodeGeneratorContext(tableConfig, classLoader)
val inputTerm = CodeGenUtils.DEFAULT_INPUT1_TERM
val collectorTerm = CodeGenUtils.DEFAULT_COLLECTOR_TERM
@@ -87,6 +92,8 @@ object CalcCodeGenerator {
outRowClass,
calcProjection,
calcCondition,
+ typeFactory,
+ inputTerm,
collectorTerm = collectorTerm,
eagerInputUnboxingCode = false,
outputDirectly = true
@@ -110,6 +117,7 @@ object CalcCodeGenerator {
outRowClass: Class[_ <: RowData],
projection: Seq[RexNode],
condition: Option[RexNode],
+ typeFactory: FlinkTypeFactory,
inputTerm: String = CodeGenUtils.DEFAULT_INPUT1_TERM,
collectorTerm: String = CodeGenUtils.DEFAULT_OPERATOR_COLLECTOR_TERM,
eagerInputUnboxingCode: Boolean,
@@ -121,7 +129,9 @@ object CalcCodeGenerator {
projection.foreach(_.accept(ScalarFunctionsValidator))
condition.foreach(_.accept(ScalarFunctionsValidator))
- val exprGenerator = new ExprCodeGenerator(ctx, false)
+ val rexProgram = buildRexProgram(typeFactory, inputType, projection, condition)
+
+ val exprGenerator = new ExprCodeGenerator(ctx, false, rexProgram)
.bindInput(inputType, inputTerm = inputTerm)
val onlyFilter = projection.lengthCompare(inputType.getFieldCount) == 0 &&
@@ -137,6 +147,8 @@ object CalcCodeGenerator {
}
def produceProjectionCode: String = {
+ val projection = rexProgram.getProjectList.asScala
+
val projectionExprs = projection.map(exprGenerator.generateExpression)
val projectionExpression =
exprGenerator.generateResultExpression(projectionExprs, outRowType, outRowClass)
@@ -162,16 +174,20 @@ object CalcCodeGenerator {
"It should be removed by CalcRemoveRule.")
} else if (condition.isEmpty) { // only projection
val projectionCode = produceProjectionCode
+ val localRefCode = ctx.reuseLocalRefCode()
s"""
|${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""}
+ |$localRefCode
|$projectionCode
|""".stripMargin
} else {
- val filterCondition = exprGenerator.generateExpression(condition.get)
+ val filterCondition = exprGenerator.generateExpression(rexProgram.getCondition)
// only filter
if (onlyFilter) {
+ val localRefCode = ctx.reuseLocalRefCode()
s"""
|${if (eagerInputUnboxingCode) ctx.reuseInputUnboxingCode() else ""}
+ |$localRefCode
|${filterCondition.code}
|if (${filterCondition.resultTerm}) {
| ${produceOutputCode(inputTerm)}
@@ -181,19 +197,35 @@ object CalcCodeGenerator {
val filterInputCode = ctx.reuseInputUnboxingCode()
val filterInputSet = Set(ctx.reusableInputUnboxingExprs.keySet.toSeq: _*)
+ val filterLocalRefSet: Set[Int] = ctx.getReusableLocalRefExprBottomScope.keySet.toSet
+
// if any filter conditions, projection code will enter a new scope
val projectionCode = produceProjectionCode
val projectionInputCode = ctx.reusableInputUnboxingExprs
- .filter(entry => !filterInputSet.contains(entry._1))
+ .filter { case (k, _) => !filterInputSet.contains(k) }
+ .values
+ .map(_.code)
+ .mkString("\n")
+
+ val filterLocalRefCode = ctx.getReusableLocalRefExprBottomScope
+ .filter { case (k, _) => filterLocalRefSet.contains(k) }
.values
.map(_.code)
.mkString("\n")
+ val projectionLocalRefCode = ctx.getReusableLocalRefExprBottomScope
+ .filter { case (k, _) => !filterLocalRefSet.contains(k) }
+ .values
+ .map(_.code)
+ .mkString("\n")
+
s"""
|${if (eagerInputUnboxingCode) filterInputCode else ""}
+ |$filterLocalRefCode
|${filterCondition.code}
|if (${filterCondition.resultTerm}) {
- | ${if (eagerInputUnboxingCode) projectionInputCode else ""}
+ | ${if (eagerInputUnboxingCode) projectionInputCode else ""}
+ | $projectionLocalRefCode
| $projectionCode
|}
|""".stripMargin
@@ -201,6 +233,22 @@ object CalcCodeGenerator {
}
}
+ private def buildRexProgram(
+ typeFactory: FlinkTypeFactory,
+ inputType: RowType,
+ projection: Seq[RexNode],
+ condition: Option[RexNode]
+ ): RexProgram = {
+ val rexBuilder = new FlinkRexBuilder(typeFactory)
+ val relInputType = typeFactory.createFieldTypeFromLogicalType(inputType)
+ val builder = new RexProgramBuilder(relInputType, rexBuilder)
+ projection.foreach(p => builder.addProject(p, null))
+ if (condition.isDefined) {
+ builder.addCondition(condition.get)
+ }
+ builder.getProgram
+ }
+
private object ScalarFunctionsValidator extends RexVisitorImpl[Unit](true) {
override def visitCall(call: RexCall): Unit = {
super.visitCall(call)
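The filter/projection split above can be condensed into a standalone sketch
(a simplified model of generateProcessCode; the strings stand in for generated
Java, and `heavy` is a hypothetical expensive function):

    import scala.collection.mutable

    // Bodies cached while generating the filter are hoisted above the `if`;
    // bodies that only the projection needs stay inside the guarded block and
    // run only for rows that pass the filter.
    val cached = mutable.LinkedHashMap(
      0 -> "int t0 = f0 + 1;",    // referenced by the filter condition
      1 -> "long t1 = heavy(f1);" // referenced only by the projection
    )
    val filterLocalRefSet = Set(0) // snapshot taken after the filter was generated
    val (hoisted, guarded) = cached.partition { case (k, _) => filterLocalRefSet(k) }

    println(s"""${hoisted.values.mkString("\n")}
               |boolean cond = t0 > 0;
               |if (cond) {
               |  ${guarded.values.mkString("\n")}
               |  // ... projection output ...
               |}""".stripMargin)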
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
index 811fd1a8420..ff924fe0f30 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGenUtils.scala
@@ -46,6 +46,8 @@ import org.apache.flink.types.{ColumnList, Row, RowKind}
import org.apache.flink.types.bitmap.Bitmap
import org.apache.flink.types.variant.Variant
+import org.apache.calcite.rex.{RexNode, RexProgram}
+
import java.lang.{Boolean => JBoolean, Byte => JByte, Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong, Object => JObject, Short => JShort}
import java.lang.reflect.Method
import java.util.concurrent.atomic.AtomicLong
@@ -1120,4 +1122,12 @@ object CodeGenUtils {
GenerateUtils.generateFieldAccess(ctx, inputType, inputTerm, index)
}
}
+
+ def getExprsFromProgramOrNull(rexProgram: RexProgram): java.util.List[RexNode] = {
+ if (rexProgram == null) {
+ null
+ } else {
+ rexProgram.getExprList
+ }
+ }
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala
index 02706cc1309..fade7279606 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/CodeGeneratorContext.scala
@@ -17,7 +17,6 @@
*/
package org.apache.flink.table.planner.codegen
-import org.apache.flink.api.common.functions.Function
import org.apache.flink.api.common.typeutils.TypeSerializer
import org.apache.flink.configuration.ReadableConfig
import org.apache.flink.table.data.GenericRowData
@@ -116,6 +115,36 @@ class CodeGeneratorContext(
val reusableInputUnboxingExprs: mutable.Map[(String, Int), GeneratedExpression] =
mutable.Map[(String, Int), GeneratedExpression]()
+ // Stack of RexLocalRef cache scopes (`exprList-index -> generated body`).
+ // * Bottom scope == getReusableLocalRefExprBottomScope: bodies are hoisted to the top of the method
+ // and run unconditionally for every row.
+ // * Inner scopes (push/popLocalRefScope): bodies are folded into a single guarded
+ // operand's code by ExprCodeGenerator.visitOperandInScopedCache and run only when
+ // the guard fires. Inserts always target the innermost scope; lookup walks innermost-out.
+ //
+ // Example — `CASE WHEN b <> 0 THEN a / b ELSE NULL`:
+ //
+ // With scoping (correct):
+ // boolean cmp = b != 0;
+ // if (cmp) {
+ // int div = a / b; // emitted inside the guarded scope
+ // result = div;
+ // } else {
+ // result = null;
+ // }
+ //
+ // Without scoping (buggy):
+ // int div = a / b; // throws ArithmeticException when b == 0
+ // boolean cmp = b != 0;
+ // if (cmp) { result = div; }
+ // else { result = null; }
+ //
+ // The set of operand positions that get scoped lives in
+ // ExprCodeGenerator.conditionalOperandIndices — extend it when adding new short-circuit
+ // operators.
+ private val localRefScopes =
+ mutable.ArrayBuffer(mutable.LinkedHashMap.empty[Int, GeneratedExpression])
+
// set of constructor statements that will be added only once
// we use a LinkedHashSet to keep the insertion order
private val reusableConstructorStatements: mutable.LinkedHashSet[(String, String)] =
@@ -783,6 +812,7 @@ class CodeGeneratorContext(
/**
* Adds a reusable Object to the member area of the generated class
+ *
* @param obj
* the object to be added to the generated class
* @param fieldNamePrefix
@@ -1075,4 +1105,61 @@ class CodeGeneratorContext(
fieldTerm
}
+
+ // ---------------------------------------------------------------------------------
+ // Reusable local ref code with scope
+ // ---------------------------------------------------------------------------------
+
+ // Bottom scope of localRefScopes: holds unconditionally evaluated local refs.
+ def getReusableLocalRefExprBottomScope: mutable.LinkedHashMap[Int, GeneratedExpression] =
+ localRefScopes(0)
+
+ /**
+ * Adds a reusable [[org.apache.calcite.rex.RexLocalRef]] expression keyed by its index in the
+ * program's exprList. The expression is stored in the innermost active scope.
+ */
+ def addReusableLocalRefExpr(index: Int, expr: GeneratedExpression): Unit =
+ localRefScopes.last(index) = expr
+
+ /**
+ * Looks up a previously cached [[org.apache.calcite.rex.RexLocalRef]] expression by its exprList
+ * index. Scopes are searched innermost-out so that a body cached inside a guarded scope takes
+ * precedence over an outer entry.
+ */
+ def getReusableLocalRefExpr(index: Int): Option[GeneratedExpression] = {
+ // Search innermost-out: a body cached in an inner (guarded) scope wins over outer
+ // entries. In practice the cache is monotone — an entry never appears in two scopes
+ // simultaneously.
+ var i = localRefScopes.size - 1
+ while (i >= 0) {
+ val maybe = localRefScopes(i).get(index)
+ if (maybe.isDefined) return maybe
+ i -= 1
+ }
+ None
+ }
+
+ /**
+ * Returns the generated code for all unconditionally-evaluated local-ref expressions (bottom
+ * scope), concatenated in insertion order.
+ */
+ def reuseLocalRefCode(): String = {
+ getReusableLocalRefExprBottomScope.values.map(_.code).mkString("\n")
+ }
+
+ /** Pushes a new, empty local-ref cache scope onto the scope stack. */
+ def pushLocalRefScope(): Unit = {
+ localRefScopes.append(mutable.LinkedHashMap.empty)
+ }
+
+ /**
+ * Pops the innermost local-ref cache scope and returns its entries. The bottom scope
+ * ([[getReusableLocalRefExprBottomScope]]) cannot be popped.
+ */
+ def popLocalRefScope(): scala.collection.Map[Int, GeneratedExpression] = {
+ require(
+ localRefScopes.size > 1,
+ "Cannot pop the bottom RexLocalRef cache scope (reusableLocalRefExprs).")
+ localRefScopes.remove(localRefScopes.size - 1)
+ }
}
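The push/pop contract can be exercised in isolation; a toy stand-alone mirror
of the scope stack (String bodies instead of GeneratedExpression):

    import scala.collection.mutable

    val scopes = mutable.ArrayBuffer(mutable.LinkedHashMap.empty[Int, String])

    def push(): Unit = scopes.append(mutable.LinkedHashMap.empty[Int, String])
    def pop(): collection.Map[Int, String] = {
      require(scopes.size > 1, "cannot pop the bottom scope")
      scopes.remove(scopes.size - 1)
    }
    def put(index: Int, body: String): Unit = scopes.last(index) = body
    def get(index: Int): Option[String] =
      // innermost-out: an entry cached inside a guarded scope wins
      scopes.reverseIterator.map(_.get(index)).collectFirst { case Some(b) => b }

    put(0, "hoisted body")
    push()
    put(1, "guarded body")
    assert(get(1).contains("guarded body")) // visible while the scope is open
    val inner = pop() // folded into the guarded operand's code by the caller
    assert(get(1).isEmpty && get(0).contains("hoisted body"))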
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
index 7154aa09f0a..1de3ca9bee8 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExprCodeGenerator.scala
@@ -37,7 +37,8 @@ import org.apache.flink.table.planner.functions.bridging.BridgingSqlFunction
import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable._
import org.apache.flink.table.planner.functions.sql.SqlThrowExceptionFunction
import org.apache.flink.table.planner.functions.utils.{ScalarSqlFunction, TableSqlFunction}
-import org.apache.flink.table.planner.plan.utils.RexLiteralUtil
+import org.apache.flink.table.planner.plan.utils.{FlinkRexUtil, RexLiteralUtil}
+import org.apache.flink.table.planner.utils.ShortcutUtils
import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromLogicalTypeToDataType
import org.apache.flink.table.runtime.types.PlannerTypeUtils.isInteroperable
import org.apache.flink.table.runtime.typeutils.TypeCheckUtils
@@ -55,9 +56,15 @@ import scala.collection.JavaConversions._
* This code generator is mainly responsible for generating codes for a given calcite [[RexNode]].
* It can also generate type conversion codes for the result converter.
*/
-class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
+class ExprCodeGenerator(
+ ctx: CodeGeneratorContext,
+ nullableInput: Boolean,
+ val rexProgram: RexProgram)
extends RexVisitor[GeneratedExpression] {
+ def this(ctx: CodeGeneratorContext, nullableInput: Boolean) =
+ this(ctx, nullableInput, null)
+
/** term of the [[ProcessFunction]]'s context, can be changed when needed */
var contextTerm = "ctx"
@@ -344,7 +351,6 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
}
override def visitInputRef(inputRef: RexInputRef): GeneratedExpression = {
- // for specific custom code generation
if (input1Type == null) {
return GeneratedExpression(
inputRef.getName,
@@ -416,8 +422,53 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
GeneratedExpression(input1Term, NEVER_NULL, NO_CODE, input1Type)
}
- override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression =
- throw new CodeGenException("RexLocalRef are not supported yet.")
+ override def visitLocalRef(localRef: RexLocalRef): GeneratedExpression = {
+ // addReusableLocalVariable
+ // for specific custom code generation
+ if (input1Type == null) {
+ return GeneratedExpression(
+ localRef.getName,
+ localRef.getName + "IsNull",
+ NO_CODE,
+ FlinkTypeFactory.toLogicalType(localRef.getType))
+ }
+ // for the general cases with a previous call to bindInput()
+ val input1Arity = input1Type match {
+ case r: RowType => r.getFieldCount
+ case _ => 1
+ }
+ if (localRef.getIndex >= input1Arity) {
+ if (rexProgram == null) {
+ throw new CodeGenException(s"RexLocalRef(${localRef.getIndex}) requires a RexProgram.")
+ }
+ val idx = localRef.getIndex
+ val target = rexProgram.getExprList.get(idx)
+ if (!isDeterministicThroughProgram(target)) {
+ return target.accept(this)
+ }
+ val full = ctx.getReusableLocalRefExpr(idx) match {
+ case Some(cached) => cached
+ case None =>
+ val expr = target.accept(this)
+ ctx.addReusableLocalRefExpr(idx, expr)
+ expr
+ }
+ return GeneratedExpression(
+ full.resultTerm,
+ full.nullTerm,
+ NO_CODE,
+ full.resultType,
+ full.literalValue)
+ }
+
+ generateInputAccess(
+ ctx,
+ input1Type,
+ input1Term,
+ localRef.getIndex,
+ nullableInput,
+ deepCopy = true)
+ }
def visitRexFieldVariable(variable: RexFieldVariable): GeneratedExpression = {
val internalType = FlinkTypeFactory.toLogicalType(variable.dataType)
@@ -462,7 +513,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
val resultType = FlinkTypeFactory.toLogicalType(call.getType)
// throw exception if json function is called outside JSON_OBJECT or JSON_ARRAY function
- if (isJsonFunctionOperand(call)) {
+ if (isJsonFunctionOperand(call, CodeGenUtils.getExprsFromProgramOrNull(rexProgram))) {
throw new ValidationException(
"The JSON() function is currently only supported inside JSON_ARRAY()
or as the VALUE param" +
" of JSON_OBJECT(). Example: JSON_OBJECT('a', JSON('{\"key\":
\"value\"}')) or " +
@@ -473,10 +524,12 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
return generateSearch(
ctx,
generateExpression(call.getOperands.get(0)),
- call.getOperands.get(1).asInstanceOf[RexLiteral])
+ rexProgram,
+ call.getOperands)
}
// convert operands and help giving untyped NULL literals a type
+ val condIdxs = conditionalOperandIndices(call)
val operands = call.getOperands.zipWithIndex.map {
// this helps e.g. for AS(null)
@@ -487,15 +540,55 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
generateNullLiteral(resultType)
// We only support the JSON function inside of JSON_OBJECT or JSON_ARRAY
- case (operand: RexNode, i) if isSupportedJsonOperand(operand, call, i) =>
+ case (operand: RexNode, i)
+ if isSupportedJsonOperand(
+ operand,
+ call,
+ i,
+ CodeGenUtils.getExprsFromProgramOrNull(rexProgram)) =>
generateJsonCall(operand)
+ case (o @ _, i) if condIdxs.contains(i) => visitOperandInScopedCache(o)
+
case (o @ _, _) => o.accept(this)
}
generateCallExpression(ctx, call, operands, resultType)
}
+ /**
+ * Indices of `call`'s operands that are NOT unconditionally evaluated at runtime. Used to scope
+ * the RexLocalRef cache so that bodies cached while visiting these operands are not hoisted out
+ * of the surrounding short-circuit / if-block.
+ *
+ * - `CASE(when_1, then_1, when_2, then_2, ..., else)`: only `when_1` is unconditional.
+ * - `AND(a_0, a_1, ..., a_n)` / `OR(...)`: only `a_0` is unconditional; subsequent operands are
+ * short-circuited by the operator semantics and the codegen.
+ */
+ private def conditionalOperandIndices(call: RexCall): Set[Int] = call.getKind match {
+ case SqlKind.CASE | SqlKind.AND | SqlKind.OR | SqlKind.COALESCE =>
+ (1 until call.getOperands.size).toSet
+ case _ => Set.empty
+ }
+
+ private def visitOperandInScopedCache(operand: RexNode): GeneratedExpression = {
+ ctx.pushLocalRefScope()
+ val (operandExpr, scopedBodies) = {
+ val expr = operand.accept(this)
+ val popped = ctx.popLocalRefScope()
+ (expr, popped.values.map(_.code).mkString("\n"))
+ }
+ if (scopedBodies.isEmpty) {
+ operandExpr
+ } else
+ GeneratedExpression(
+ operandExpr.resultTerm,
+ operandExpr.nullTerm,
+ scopedBodies + "\n" + operandExpr.code,
+ operandExpr.resultType,
+ operandExpr.literalValue)
+ }
+
override def visitOver(over: RexOver): GeneratedExpression =
throw new CodeGenException("Aggregate functions over windows are not
supported yet.")
@@ -786,9 +879,11 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
case JSON_QUERY => new JsonQueryCallGen().generate(ctx, operands, resultType)
- case JSON_OBJECT => new JsonObjectCallGen(call).generate(ctx, operands, resultType)
+ case JSON_OBJECT =>
+ new JsonObjectCallGen(call, rexProgram).generate(ctx, operands, resultType)
- case JSON_ARRAY => new JsonArrayCallGen(call).generate(ctx, operands, resultType)
+ case JSON_ARRAY =>
+ new JsonArrayCallGen(call, rexProgram).generate(ctx, operands, resultType)
case _: SqlThrowExceptionFunction =>
val nullValue = generateNullLiteral(resultType)
@@ -827,7 +922,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
generateGreatestLeast(ctx, resultType, operands, greatest = false)
case BuiltInFunctionDefinitions.JSON_STRING =>
- new JsonStringCallGen(call).generate(ctx, operands, resultType)
+ new JsonStringCallGen(call, rexProgram).generate(ctx, operands, resultType)
case BuiltInFunctionDefinitions.INTERNAL_HASHCODE =>
new HashCodeCallGen().generate(ctx, operands, resultType)
@@ -847,7 +942,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
new JsonCallGen().generate(ctx, operands, FlinkTypeFactory.toLogicalType(call.getType))
case _ =>
- new BridgingSqlFunctionCallGen(call).generate(ctx, operands, resultType)
+ new BridgingSqlFunctionCallGen(call, rexProgram).generate(ctx, operands, resultType)
}
// advanced scalar functions
@@ -875,7 +970,14 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
}
private def generateJsonCall(operand: RexNode) = {
- val jsonCall = operand.asInstanceOf[RexCall]
+ // After unification of projections + condition into a single RexProgram, structurally
+ // identical sub-expressions are collapsed into one exprList entry referenced via
+ // RexLocalRef. JSON_OBJECT/JSON_ARRAY operands recognised as JSON via
+ // isSupportedJsonOperand may therefore arrive here as a RexLocalRef; resolve it back to
+ // the underlying RexCall before casting.
+ val jsonCall = FlinkRexUtil
+ .expandLocalRef(operand, CodeGenUtils.getExprsFromProgramOrNull(rexProgram))
+ .asInstanceOf[RexCall]
val jsonOperands = jsonCall.getOperands.map(_.accept(this))
generateCallExpression(
ctx,
@@ -896,4 +998,9 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
}
}.toArray
}
+
+ private def isDeterministicThroughProgram(node: RexNode): Boolean =
+ ShortcutUtils.isDeterministicThroughProgram(
+ node,
+ CodeGenUtils.getExprsFromProgramOrNull(rexProgram))
}
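To make the visitLocalRef contract concrete: for a query like
`SELECT (f0 + 1) * 2 FROM t WHERE f0 + 1 > 0` the unified program holds a
single exprList entry for `f0 + 1`, and the generated code reuses its result
term (schematic shape only, with hypothetical term names, not verbatim
generator output):

    int t2 = f0 + 1;       // body of the shared local ref, emitted once
    boolean cond = t2 > 0; // filter reuses the cached result term
    if (cond) {
      int out = t2 * 2;    // projection reference returns NO_CODE + term
    }

A non-deterministic entry (e.g. one containing RAND()) fails
isDeterministicThroughProgram and is re-expanded at every reference, so each
use still evaluates independently.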
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala
index 567554e6e1c..b1c70fb9c52 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala
@@ -139,8 +139,8 @@ class ExpressionReducer(
} else
unreduced match {
case call: RexCall
- if (nonReducibleJsonFunctions.contains(call.getOperator) || isJsonFunctionOperand(
- call)) =>
+ if (nonReducibleJsonFunctions.contains(call.getOperator)
+ || isJsonFunctionOperand(call, null)) =>
reducedValues.add(unreduced)
case _ =>
unreduced.getType.getSqlTypeName match {
@@ -297,7 +297,10 @@ class ExpressionReducer(
}
// Exclude some JSON functions which behave differently
// when called as an argument of another call of one of these functions.
- if (nonReducibleJsonFunctions.contains(call.getOperator) || isJsonFunctionOperand(call)) {
+ if (
+ nonReducibleJsonFunctions.contains(call.getOperator)
+ || isJsonFunctionOperand(call, null)
+ ) {
None
} else {
Some(call)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala
index 817d399bacc..3131de4f8f9 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/FunctionCodeGenerator.scala
@@ -18,12 +18,11 @@
package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.functions._
-import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.functions.ProcessFunction
import org.apache.flink.streaming.api.functions.async.{AsyncFunction,
RichAsyncFunction}
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.Indenter.toISC
-import org.apache.flink.table.runtime.generated.{FilterCondition, GeneratedFilterCondition, GeneratedFunction, GeneratedJoinCondition, JoinCondition}
+import org.apache.flink.table.runtime.generated._
import org.apache.flink.table.types.logical.LogicalType
/**
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala
index 3fd256e071d..a64a97727cc 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/JsonGenerateUtils.scala
@@ -21,11 +21,12 @@ import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{ArrayNode, ObjectNode}
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.util.RawValue
import org.apache.flink.table.api.{DataTypes, JsonOnNull}
-import org.apache.flink.table.functions.BuiltInFunctionDefinitions.JSON
+import org.apache.flink.table.functions.{BuiltInFunctionDefinitions, FunctionDefinition}
import org.apache.flink.table.planner.codegen.CodeGenUtils._
-import org.apache.flink.table.planner.functions.sql.FlinkSqlOperatorTable.{JSON_ARRAY, JSON_OBJECT}
+import org.apache.flink.table.planner.functions.sql.{SqlJsonArrayFunctionWrapper, SqlJsonObjectFunctionWrapper, SqlJsonQueryFunctionWrapper, SqlJsonValueFunctionWrapper}
+import org.apache.flink.table.planner.plan.utils.FlinkRexUtil
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
-import org.apache.flink.table.planner.utils.ShortcutUtils.unwrapFunctionDefinition
+import org.apache.flink.table.planner.utils.ShortcutUtils
import org.apache.flink.table.runtime.functions.SqlJsonUtils
import org.apache.flink.table.runtime.typeutils.TypeCheckUtils.isCharacterString
import org.apache.flink.table.types.logical._
@@ -51,8 +52,9 @@ object JsonGenerateUtils {
def createNodeTerm(
ctx: CodeGeneratorContext,
expression: GeneratedExpression,
- operand: RexNode): String = {
- if (isJsonObjectOrArrayOperand(operand) || isJsonFunctionOperand(operand)) {
+ operand: RexNode,
+ exprs: java.util.List[RexNode]): String = {
+ if (isJsonObjectOrArrayOperand(operand, exprs) || isJsonFunctionOperand(operand, exprs)) {
createRawNodeTerm(expression)
} else {
createNodeTerm(ctx, expression)
@@ -177,59 +179,36 @@ object JsonGenerateUtils {
}
}
- /** Determines whether the given operand is a call to a JSON_OBJECT */
- def isJsonObjectOperand(operand: RexNode): Boolean = {
- operand match {
- case rexCall: RexCall =>
- rexCall.getOperator match {
- case JSON_OBJECT => true
- case _ => false
- }
- case _ => false
- }
- }
+ /** Determines whether the given operand is a call to a JSON_OBJECT. */
+ def isJsonObjectOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean =
+ isOneOfFunctionDefinitions(
+ FlinkRexUtil.expandLocalRef(operand, localRefs),
+ BuiltInFunctionDefinitions.JSON_OBJECT)
- /** Determines whether the given operand is a call to a JSON_ARRAY */
- def isJsonArrayOperand(operand: RexNode): Boolean = {
- operand match {
- case rexCall: RexCall =>
- rexCall.getOperator match {
- case JSON_ARRAY => true
- case _ => false
- }
- case _ => false
- }
- }
+ /** Determines whether the given operand is a call to a JSON_ARRAY. */
+ def isJsonArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean =
+ isOneOfFunctionDefinitions(
+ FlinkRexUtil.expandLocalRef(operand, localRefs),
+ BuiltInFunctionDefinitions.JSON_ARRAY)
/**
* Determines whether the given operand is a call to a JSON_OBJECT or JSON_ARRAY whose result
* should be inserted as a raw value instead of as a character string.
*/
- def isJsonObjectOrArrayOperand(operand: RexNode): Boolean = {
- operand match {
- case rexCall: RexCall =>
- rexCall.getOperator match {
- case JSON_OBJECT | JSON_ARRAY => true
- case _ => false
- }
- case _ => false
- }
- }
def isJsonObjectOrArrayOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean =
+ isOneOfFunctionDefinitions(
+ FlinkRexUtil.expandLocalRef(operand, localRefs),
+ BuiltInFunctionDefinitions.JSON_OBJECT,
+ BuiltInFunctionDefinitions.JSON_ARRAY)
/**
* Determines whether the given operand is a call to JSON function whose call currently just
- * passes through the input value as output value
+ * passes through the input value as output value.
*/
- def isJsonFunctionOperand(operand: RexNode): Boolean = {
- operand match {
- case rexCall: RexCall =>
- unwrapFunctionDefinition(rexCall) match {
- case JSON => true
- case _ => false
- }
- case _ => false
- }
- }
+ def isJsonFunctionOperand(operand: RexNode, localRefs: java.util.List[RexNode]): Boolean =
+ isOneOfFunctionDefinitions(
+ FlinkRexUtil.expandLocalRef(operand, localRefs),
+ BuiltInFunctionDefinitions.JSON)
/**
* Determines whether a JSON function is allowed in the current context. JSON functions are
@@ -237,9 +216,13 @@ object JsonGenerateUtils {
* of a JSON_OBJECT call, we do (i % 2) == 0 to check if it's being used in second parameter, the
* values' parameter.
*/
- def isSupportedJsonOperand(operand: RexNode, call: RexNode, i: Int): Boolean = {
- isJsonFunctionOperand(operand) &&
- (isJsonArrayOperand(call) || isJsonObjectOperand(call) && (i % 2) == 0)
+ def isSupportedJsonOperand(
+ operand: RexNode,
+ call: RexNode,
+ i: Int,
+ localRefs: java.util.List[RexNode]): Boolean = {
+ isJsonFunctionOperand(operand, localRefs) &&
+ (isJsonArrayOperand(call, localRefs) || isJsonObjectOperand(call, localRefs) && (i % 2) == 0)
}
/** Generates a method to convert arrays into [[ArrayNode]]. */
@@ -331,4 +314,23 @@ object JsonGenerateUtils {
ctx.addReusableMember(methodCode)
methodName
}
+
+ def isOneOfFunctionDefinitions(
+ rexNode: RexNode,
+ expectedDefinitions: FunctionDefinition*): Boolean = {
+ if (!rexNode.isInstanceOf[RexCall]) return false
+ val call = rexNode.asInstanceOf[RexCall]
+ val unwrapped = ShortcutUtils.unwrapFunctionDefinition(call) match {
+ case d if d != null => d
+ case _ =>
+ call.getOperator match {
+ case _: SqlJsonArrayFunctionWrapper => BuiltInFunctionDefinitions.JSON_ARRAY
+ case _: SqlJsonObjectFunctionWrapper => BuiltInFunctionDefinitions.JSON_OBJECT
+ case _: SqlJsonQueryFunctionWrapper => BuiltInFunctionDefinitions.JSON_QUERY
+ case _: SqlJsonValueFunctionWrapper => BuiltInFunctionDefinitions.JSON_VALUE
+ case _ => return false
+ }
+ }
+ expectedDefinitions.exists(_ eq unwrapped)
+ }
}
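Call-site sketch for the new signatures (assuming a RexProgram `program` and
an operand `op` are in scope; `exprs` may be null for callers without a
program, in which case local refs are simply not expanded):

    val exprs = CodeGenUtils.getExprsFromProgramOrNull(program)
    // `op` may now be a RexLocalRef pointing at a JSON_OBJECT/JSON_ARRAY call
    // in the program's exprList; expandLocalRef resolves it before classifying.
    val insertAsRawNode =
      JsonGenerateUtils.isJsonObjectOrArrayOperand(op, exprs) ||
        JsonGenerateUtils.isJsonFunctionOperand(op, exprs)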
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
index 4256c90b596..bae94282d1d 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LongHashJoinGenerator.scala
@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.codegen
import org.apache.flink.api.common.functions.DefaultOpenContext
-import org.apache.flink.configuration.{Configuration, ReadableConfig}
+import org.apache.flink.configuration.ReadableConfig
import org.apache.flink.metrics.Gauge
import org.apache.flink.table.data.{RowData, TimestampData}
import org.apache.flink.table.data.utils.JoinedRowData
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
index db158572df9..e011dd2cfce 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/LookupJoinCodeGenerator.scala
@@ -354,14 +354,16 @@ object LookupJoinCodeGenerator {
projection: Seq[RexNode],
condition: RexNode,
outputType: RelDataType,
- tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
+ tableSourceRowType: RowType,
+ typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
generateCalcMapFunction(
tableConfig,
classLoader,
projection,
condition,
FlinkTypeFactory.toLogicalRowType(outputType),
- tableSourceRowType
+ tableSourceRowType,
+ typeFactory
)
}
@@ -375,7 +377,8 @@ object LookupJoinCodeGenerator {
projection: Seq[RexNode],
condition: RexNode,
outputType: RowType,
- tableSourceRowType: RowType): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
+ tableSourceRowType: RowType,
+ typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
generateCalcMapFunction(
tableConfig,
classLoader,
@@ -383,7 +386,8 @@ object LookupJoinCodeGenerator {
condition,
outputType,
tableSourceRowType,
- "TableCalcMapFunction")
+ "TableCalcMapFunction",
+ typeFactory)
}
/**
@@ -397,7 +401,8 @@ object LookupJoinCodeGenerator {
condition: RexNode,
outputType: RowType,
tableSourceRowType: RowType,
- name: String): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
+ name: String,
+ typeFactory: FlinkTypeFactory): GeneratedFunction[FlatMapFunction[RowData, RowData]] = {
CalcCodeGenerator.generateFunction(
tableSourceRowType,
name,
@@ -406,7 +411,8 @@ object LookupJoinCodeGenerator {
projection,
Option(condition),
tableConfig,
- classLoader
+ classLoader,
+ typeFactory
)
}
}
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala
index cb241c8feb4..2115e76304d 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingFunctionGenUtil.scala
@@ -18,7 +18,7 @@
package org.apache.flink.table.planner.codegen.calls
import org.apache.flink.api.common.functions.{AbstractRichFunction, OpenContext, RichFunction}
-import org.apache.flink.configuration.{Configuration, ReadableConfig}
+import org.apache.flink.configuration.ReadableConfig
import org.apache.flink.table.api.{DataTypes, TableException}
import org.apache.flink.table.api.Expressions.callSql
import org.apache.flink.table.data.{GenericRowData, RawValueData, StringData}
@@ -27,9 +27,10 @@ import org.apache.flink.table.expressions.ApiExpressionUtils.{typeLiteral, unres
import org.apache.flink.table.expressions.Expression
import org.apache.flink.table.functions._
import org.apache.flink.table.functions.SpecializedFunction.{ExpressionEvaluator, ExpressionEvaluatorFactory}
-import org.apache.flink.table.functions.UserDefinedFunctionHelper.{validateClassForRuntime, ASYNC_SCALAR_EVAL, ASYNC_TABLE_EVAL, SCALAR_EVAL, TABLE_EVAL}
+import org.apache.flink.table.functions.UserDefinedFunctionHelper._
import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, RexFactory}
import org.apache.flink.table.planner.codegen._
+import org.apache.flink.table.planner.codegen.AsyncCodeGenerator.DEFAULT_DELEGATING_FUTURE_TERM
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.GeneratedExpression.{NEVER_NULL, NO_CODE}
import org.apache.flink.table.planner.utils.JavaScalaConversionUtil.toScala
@@ -48,8 +49,6 @@ import org.apache.flink.table.types.utils.DataTypeUtils
import org.apache.flink.table.types.utils.DataTypeUtils.{isInternal, validateInputDataType, validateOutputDataType}
import org.apache.flink.util.Preconditions
-import AsyncCodeGenerator.{generateFunction, DEFAULT_DELEGATING_FUTURE_TERM}
-
import java.util.concurrent.CompletableFuture
import scala.collection.JavaConverters._
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala
index cada84a1bca..83223e63200 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/BridgingSqlFunctionCallGen.scala
@@ -26,7 +26,7 @@ import org.apache.flink.table.planner.functions.inference.OperatorBindingCallCon
import org.apache.flink.table.runtime.collector.WrappingCollector
import org.apache.flink.table.types.logical.LogicalType
-import org.apache.calcite.rex.{RexCall, RexCallBinding}
+import org.apache.calcite.rex.{RexCall, RexCallBinding, RexProgram}
import java.util.Collections
@@ -37,7 +37,7 @@ import java.util.Collections
* generator will be a reference to a [[WrappingCollector]]. Furthermore, atomic types are wrapped
* into a row by the collector.
*/
-class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator {
+class BridgingSqlFunctionCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator {
override def generate(
ctx: CodeGeneratorContext,
@@ -54,7 +54,7 @@ class BridgingSqlFunctionCallGen(call: RexCall) extends CallGenerator {
val callContext = new OperatorBindingCallContext(
dataTypeFactory,
definition,
- RexCallBinding.create(function.getTypeFactory, call, Collections.emptyList()),
+ RexCallBinding.create(function.getTypeFactory, call, rexProgram, Collections.emptyList()),
call.getType)
// create the final UDF for runtime
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala
index ea324ec6082..524e1cb7635 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonArrayCallGen.scala
@@ -19,16 +19,17 @@ package org.apache.flink.table.planner.codegen.calls
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{ArrayNode, NullNode}
import org.apache.flink.table.api.JsonOnNull
-import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression}
import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, newName, primitiveTypeTermForType, BINARY_STRING}
import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, getOnNullBehavior}
import org.apache.flink.table.runtime.functions.SqlJsonUtils
import org.apache.flink.table.types.logical.LogicalType
-import org.apache.calcite.rex.RexCall
+import org.apache.calcite.rex.{RexCall, RexProgram}
/** [[CallGenerator]] for `JSON_ARRAY`. */
-class JsonArrayCallGen(call: RexCall) extends CallGenerator {
+class JsonArrayCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator {
+
private def jsonUtils = className[SqlJsonUtils]
override def generate(
@@ -47,7 +48,9 @@ class JsonArrayCallGen(call: RexCall) extends CallGenerator {
.drop(1)
.map {
case (elementExpr, elementIdx) =>
- val elementTerm = createNodeTerm(ctx, elementExpr, call.operands.get(elementIdx))
+ val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram)
+ val elementTerm =
+ createNodeTerm(ctx, elementExpr, call.operands.get(elementIdx), exprs)
onNull match {
case JsonOnNull.NULL =>
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala
index 9a5c87fb06f..37d126776c4 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonObjectCallGen.scala
@@ -19,13 +19,13 @@ package org.apache.flink.table.planner.codegen.calls
import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.node.{NullNode, ObjectNode}
import org.apache.flink.table.api.JsonOnNull
-import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression}
import org.apache.flink.table.planner.codegen.CodeGenUtils._
import org.apache.flink.table.planner.codegen.JsonGenerateUtils.{createNodeTerm, getOnNullBehavior}
import org.apache.flink.table.runtime.functions.SqlJsonUtils
import org.apache.flink.table.types.logical.LogicalType
-import org.apache.calcite.rex.RexCall
+import org.apache.calcite.rex.{RexCall, RexProgram}
/**
* [[CallGenerator]] for `JSON_OBJECT`.
@@ -37,7 +37,8 @@ import org.apache.calcite.rex.RexCall
* We remedy this by treating nested calls to this function differently and inserting the value as a
* raw node instead of as a string node.
*/
-class JsonObjectCallGen(call: RexCall) extends CallGenerator {
+class JsonObjectCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator {
+
private def jsonUtils = className[SqlJsonUtils]
override def generate(
@@ -57,7 +58,8 @@ class JsonObjectCallGen(call: RexCall) extends CallGenerator {
.grouped(2)
.map {
case Seq((keyExpr, _), (valueExpr, valueIdx)) =>
- val valueTerm = createNodeTerm(ctx, valueExpr,
call.operands.get(valueIdx))
+ val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram)
+ val valueTerm = createNodeTerm(ctx, valueExpr,
call.operands.get(valueIdx), exprs)
onNull match {
case JsonOnNull.NULL =>
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala
index 1265d11a278..3655943327b 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/JsonStringCallGen.scala
@@ -17,16 +17,17 @@
  */
 package org.apache.flink.table.planner.codegen.calls
-import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, GeneratedExpression}
+import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, CodeGenUtils, GeneratedExpression}
 import org.apache.flink.table.planner.codegen.CodeGenUtils.{className, newName, primitiveTypeTermForType, BINARY_STRING}
 import org.apache.flink.table.planner.codegen.JsonGenerateUtils.createNodeTerm
 import org.apache.flink.table.runtime.functions.SqlJsonUtils
 import org.apache.flink.table.types.logical.LogicalType
-import org.apache.calcite.rex.RexCall
+import org.apache.calcite.rex.{RexCall, RexProgram}
 /** [[CallGenerator]] for `JSON_STRING`. */
-class JsonStringCallGen(call: RexCall) extends CallGenerator {
+class JsonStringCallGen(call: RexCall, rexProgram: RexProgram) extends CallGenerator {
+
   private def jsonUtils = className[SqlJsonUtils]
   override def generate(
@@ -34,7 +35,8 @@ class JsonStringCallGen(call: RexCall) extends CallGenerator {
       operands: Seq[GeneratedExpression],
       returnType: LogicalType): GeneratedExpression = {
-    val valueTerm = createNodeTerm(ctx, operands.head, call.operands.get(0))
+    val exprs = CodeGenUtils.getExprsFromProgramOrNull(rexProgram)
+    val valueTerm = createNodeTerm(ctx, operands.head, call.operands.get(0), exprs)
     val resultTerm = newName(ctx, "result")
     val resultTermType = primitiveTypeTermForType(returnType)
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala
index de55d1f8c30..0f3e36e6417 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/calls/SearchOperatorGen.scala
@@ -27,7 +27,7 @@ import org.apache.flink.table.planner.plan.utils.RexLiteralUtil.toFlinkInternalV
 import org.apache.flink.table.types.logical.{BooleanType, LogicalType}
 import org.apache.flink.table.types.logical.utils.LogicalTypeMerging.findCommonType
-import org.apache.calcite.rex.{RexLiteral, RexUnknownAs}
+import org.apache.calcite.rex.{RexLiteral, RexLocalRef, RexNode, RexProgram, RexUnknownAs}
 import org.apache.calcite.util.{RangeSets, Sarg}
 import java.util.Arrays.asList
@@ -53,7 +53,17 @@ object SearchOperatorGen {
   def generateSearch(
       ctx: CodeGeneratorContext,
       target: GeneratedExpression,
-      sargLiteral: RexLiteral): GeneratedExpression = {
+      rexProgram: RexProgram,
+      operands: java.util.List[RexNode]): GeneratedExpression = {
+    val sargLiteral =
+      if (rexProgram != null && operands.get(1).isInstanceOf[RexLocalRef]) {
+        rexProgram.getExprList
+          .get(operands.get(1).asInstanceOf[RexLocalRef].getIndex)
+          .asInstanceOf[RexLiteral]
+      } else {
+        operands.get(1).asInstanceOf[RexLiteral]
+      }
+
     val sarg: Sarg[Nothing] = sargLiteral.getValueAs(classOf[Sarg[Nothing]])
     val targetType = target.resultType
     val sargType = FlinkTypeFactory.toLogicalType(sargLiteral.getType)
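generateSearch now resolves the Sarg literal itself because, in a Calc whose program reuses expressions, the second SEARCH operand may arrive as a RexLocalRef pointing into the program's expression list. Calcite ships an expander with the same effect; a sketch of an equivalent resolution built on RexProgram.expandLocalRef, which recursively inlines the referenced expression (names here are illustrative, not part of the commit):

    import java.util.List;
    import org.apache.calcite.rex.RexLiteral;
    import org.apache.calcite.rex.RexLocalRef;
    import org.apache.calcite.rex.RexNode;
    import org.apache.calcite.rex.RexProgram;

    final class SargResolution {
        // Illustrative equivalent of the hunk above; expandLocalRef subsumes the
        // single-index lookup for the literal case.
        static RexLiteral resolveSargLiteral(RexProgram program, List<RexNode> operands) {
            RexNode second = operands.get(1);
            if (program != null && second instanceof RexLocalRef) {
                return (RexLiteral) program.expandLocalRef((RexLocalRef) second);
            }
            return (RexLiteral) second;
        }
    }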
diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala
index 5cdbbff533a..94d262463c6 100644
--- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala
+++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRexUtil.scala
@@ -525,6 +525,14 @@ object FlinkRexUtil {
         RexUtil.expandSearch(rexBuilder, program, program.expandLocalRef(program.getCondition)))
     RelOptUtil.conjunctions(condition)
   }
+
+  def expandLocalRef(operand: RexNode, localRefs: util.List[RexNode]): RexNode = {
+    var expanded = operand
+    while (expanded.isInstanceOf[RexLocalRef] && localRefs != null) {
+      expanded = localRefs.get(expanded.asInstanceOf[RexLocalRef].getIndex)
+    }
+    expanded
+  }
 }
/**
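FlinkRexUtil.expandLocalRef follows chains of local refs (a slot in the expression list may itself reference an earlier slot) and leaves the operand untouched when no expression list is supplied. A small usage sketch; the literal and refs are constructed ad hoc and are illustrative only:

    import java.util.List;
    import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
    import org.apache.calcite.rex.RexBuilder;
    import org.apache.calcite.rex.RexLiteral;
    import org.apache.calcite.rex.RexLocalRef;
    import org.apache.calcite.rex.RexNode;

    final class ExpandLocalRefSketch {
        public static void main(String[] args) {
            RexBuilder builder = new RexBuilder(new JavaTypeFactoryImpl());
            RexLiteral lit = builder.makeLiteral("x");
            // exprs[0] = 'x'; exprs[1] = $t0 -- a two-hop chain ending at the literal.
            RexLocalRef refToSlot0 = new RexLocalRef(0, lit.getType());
            RexLocalRef refToSlot1 = new RexLocalRef(1, lit.getType());
            List<RexNode> exprs = List.of(lit, refToSlot0);
            // FlinkRexUtil.expandLocalRef(refToSlot1, exprs) follows $t1 -> $t0 -> 'x';
            // FlinkRexUtil.expandLocalRef(refToSlot1, null) returns $t1 unchanged.
        }
    }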
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java
index 6c236ccd04d..1ec9561b250 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/functions/JsonFunctionsITCase.java
@@ -41,7 +41,6 @@ import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.time.LocalDateTime;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -94,6 +93,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
testCases.addAll(jsonQuoteSpec());
testCases.addAll(jsonUnquoteSpecWithValidInput());
testCases.addAll(jsonUnquoteSpecWithInvalidInput());
+ testCases.addAll(jsonLocalRefReuseSpec());
return testCases.stream();
}
@@ -296,7 +296,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
}
private static List<TestSetSpec> isJsonSpec() {
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.IS_JSON)
.onFieldsWithData(1)
.andDataTypes(INT())
@@ -367,7 +367,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
private static List<TestSetSpec> jsonQuerySpec() {
final String jsonValue = getJsonFromResource("/json/json-query.json");
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_QUERY)
.onFieldsWithData((String) null)
.andDataTypes(STRING())
@@ -599,7 +599,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
multisetData.put("M1", 1);
multisetData.put("M2", 2);
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_STRING)
.onFieldsWithData(0)
.testResult(
@@ -616,7 +616,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
1.23,
LocalDateTime.parse("1990-06-02T13:37:42.001"),
Instant.parse("1990-06-02T13:37:42.001Z"),
- Arrays.asList("A1", "A2", "A3"),
+ List.of("A1", "A2", "A3"),
Row.of("R1",
Instant.parse("1990-06-02T13:37:42.001Z")),
mapData,
multisetData,
@@ -717,7 +717,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
}
private static List<TestSetSpec> jsonSpec() {
- return Arrays.asList(
+ return List.of(
                TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_OBJECT)
                        .onFieldsWithData("{\"key\":\"value\"}", "{\"key\": {\"value\": 42}}")
                        .andDataTypes(STRING(), STRING())
@@ -946,7 +946,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
multisetData.put("M1", 1);
multisetData.put("M2", 2);
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_OBJECT)
.onFieldsWithData(0)
.testResult(
@@ -977,7 +977,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
1.23,
LocalDateTime.parse("1990-06-02T13:37:42.001"),
Instant.parse("1990-06-02T13:37:42.001Z"),
- Arrays.asList("A1", "A2", "A3"),
+ List.of("A1", "A2", "A3"),
Row.of("R1",
Instant.parse("1990-06-02T13:37:42.001Z")),
mapData,
multisetData,
@@ -1105,7 +1105,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
private static List<TestSetSpec> jsonQuoteSpec() {
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_QUOTE)
.onFieldsWithData(0)
.testResult(
@@ -1177,7 +1177,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
private static List<TestSetSpec> jsonUnquoteSpecWithValidInput() {
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_UNQUOTE)
.onFieldsWithData(0)
.testResult(
@@ -1319,7 +1319,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
private static List<TestSetSpec> jsonUnquoteSpecWithInvalidInput() {
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_UNQUOTE)
.onFieldsWithData(0)
.testResult(
@@ -1406,7 +1406,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
multisetData.put("M1", 1);
multisetData.put("M2", 2);
- return Arrays.asList(
+ return List.of(
TestSetSpec.forFunction(BuiltInFunctionDefinitions.JSON_ARRAY)
.onFieldsWithData(0)
.testResult(
@@ -1436,7 +1436,7 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
1.23,
LocalDateTime.parse("1990-06-02T13:37:42.001"),
Instant.parse("1990-06-02T13:37:42.001Z"),
- Arrays.asList("A1", "A2", "A3"),
+ List.of("A1", "A2", "A3"),
Row.of("R1",
Instant.parse("1990-06-02T13:37:42.001Z")),
mapData,
multisetData,
@@ -1540,6 +1540,182 @@ class JsonFunctionsITCase extends BuiltInFunctionTestBase {
STRING().notNull()));
}
+    /** Pins the local-ref / common-sub-expression handling for JSON construction calls. */
+    private static List<TestSetSpec> jsonLocalRefReuseSpec() {
+        return List.of(
+                // Shared JSON(f) inside two JSON_OBJECT projections.
+                TestSetSpec.forFunction(
+                                BuiltInFunctionDefinitions.JSON_OBJECT,
+                                "Shared JSON(f) sub-expression across JSON_OBJECT projections")
+                        .onFieldsWithData("[1,2,3]")
+                        .andDataTypes(STRING())
+                        .testResult(
+                                resultSpec(
+                                        jsonObject(JsonOnNull.NULL, "k1", json($("f0"))),
+                                        "JSON_OBJECT(KEY 'k1' VALUE JSON(f0))",
+                                        "{\"k1\":[1,2,3]}",
+                                        STRING().notNull(),
+                                        STRING().notNull()),
+                                resultSpec(
+                                        jsonObject(JsonOnNull.NULL, "k2", json($("f0"))),
+                                        "JSON_OBJECT(KEY 'k2' VALUE JSON(f0))",
+                                        "{\"k2\":[1,2,3]}",
+                                        STRING().notNull(),
+                                        STRING().notNull()))
+                        .testSqlResult(
+                                "JSON_OBJECT(KEY 'k1' VALUE JSON(f0)),"
+                                        + " JSON_OBJECT(KEY 'k2' VALUE JSON(f0))",
+                                List.of("{\"k1\":[1,2,3]}", "{\"k2\":[1,2,3]}"),
+                                List.of(STRING().notNull(), STRING().notNull())),
+                // Shared JSON_ARRAY(...) inside two JSON_OBJECT projections.
+                TestSetSpec.forFunction(
+                                BuiltInFunctionDefinitions.JSON_OBJECT,
+                                "Shared JSON_ARRAY sub-expression across JSON_OBJECT projections")
+                        .onFieldsWithData(1, 2, 3)
+                        .andDataTypes(INT(), INT(), INT())
+                        .testResult(
+                                resultSpec(
+                                        jsonObject(
+                                                JsonOnNull.NULL,
+                                                "a",
+                                                jsonArray(JsonOnNull.NULL, $("f0"), $("f1"), $("f2"))),
+                                        "JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(f0, f1, f2))",
+                                        "{\"a\":[1,2,3]}",
+                                        STRING().notNull(),
+                                        STRING().notNull()),
+                                resultSpec(
+                                        jsonObject(
+                                                JsonOnNull.NULL,
+                                                "b",
+                                                jsonArray(JsonOnNull.NULL, $("f0"), $("f1"), $("f2"))),
+                                        "JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(f0, f1, f2))",
+                                        "{\"b\":[1,2,3]}",
+                                        STRING().notNull(),
+                                        STRING().notNull()))
+                        .testSqlResult(
+                                "JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(f0, f1, f2)),"
+                                        + " JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(f0, f1, f2))",
+                                List.of("{\"a\":[1,2,3]}", "{\"b\":[1,2,3]}"),
+                                List.of(STRING().notNull(), STRING().notNull())),
+                // Shared inner JSON_OBJECT inside two outer JSON_OBJECT projections.
+                TestSetSpec.forFunction(
+                                BuiltInFunctionDefinitions.JSON_OBJECT,
+                                "Shared inner JSON_OBJECT across outer JSON_OBJECT projections")
+                        .onFieldsWithData("V")
+                        .andDataTypes(STRING())
+                        .testResult(
+                                resultSpec(
+                                        jsonObject(
+                                                JsonOnNull.NULL,
+                                                "outer1",
+                                                jsonObject(JsonOnNull.NULL, "inner", $("f0"))),
+                                        "JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))",
+                                        "{\"outer1\":{\"inner\":\"V\"}}",
+                                        STRING().notNull(),
+                                        STRING().notNull()),
+                                resultSpec(
+                                        jsonObject(
+                                                JsonOnNull.NULL,
+                                                "outer2",
+                                                jsonObject(JsonOnNull.NULL, "inner", $("f0"))),
+                                        "JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))",
+                                        "{\"outer2\":{\"inner\":\"V\"}}",
+                                        STRING().notNull(),
+                                        STRING().notNull()))
+                        .testSqlResult(
+                                "JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'inner' VALUE f0)),"
+                                        + " JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'inner' VALUE f0))",
+                                List.of(
+                                        "{\"outer1\":{\"inner\":\"V\"}}",
+                                        "{\"outer2\":{\"inner\":\"V\"}}"),
+                                List.of(STRING().notNull(), STRING().notNull())),
+                // Shared JSON_OBJECT inside two JSON_ARRAY projections.
+                TestSetSpec.forFunction(
+                                BuiltInFunctionDefinitions.JSON_ARRAY,
+                                "Shared JSON_OBJECT inside JSON_ARRAY across projections")
+                        .onFieldsWithData("V")
+                        .andDataTypes(STRING())
+                        .testResult(
+                                resultSpec(
+                                        jsonArray(
+                                                JsonOnNull.NULL,
+                                                jsonObject(JsonOnNull.NULL, "k", $("f0"))),
+                                        "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))",
+                                        "[{\"k\":\"V\"}]",
+                                        STRING().notNull(),
+                                        STRING().notNull()),
+                                resultSpec(
+                                        jsonArray(
+                                                JsonOnNull.NULL,
+                                                jsonObject(JsonOnNull.NULL, "k", $("f0"))),
+                                        "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))",
+                                        "[{\"k\":\"V\"}]",
+                                        STRING().notNull(),
+                                        STRING().notNull()))
+                        .testSqlResult(
+                                "JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0)),"
+                                        + " JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE f0))",
+                                List.of("[{\"k\":\"V\"}]", "[{\"k\":\"V\"}]"),
+                                List.of(STRING().notNull(), STRING().notNull())),
+                // Shared JSON(f) inside two JSON_ARRAY projections.
+                TestSetSpec.forFunction(
+                                BuiltInFunctionDefinitions.JSON_ARRAY,
+                                "Shared JSON(f) inside JSON_ARRAY across projections")
+                        .onFieldsWithData("[1,2,3]")
+                        .andDataTypes(STRING())
+                        .testResult(
+                                resultSpec(
+                                        jsonArray(JsonOnNull.NULL, json($("f0"))),
+                                        "JSON_ARRAY(JSON(f0))",
+                                        "[[1,2,3]]",
+                                        STRING().notNull(),
+                                        STRING().notNull()),
+                                resultSpec(
+                                        jsonArray(JsonOnNull.NULL, json($("f0"))),
+                                        "JSON_ARRAY(JSON(f0))",
+                                        "[[1,2,3]]",
+                                        STRING().notNull(),
+                                        STRING().notNull()))
+                        .testSqlResult(
+                                "JSON_ARRAY(JSON(f0)), JSON_ARRAY(JSON(f0))",
+                                List.of("[[1,2,3]]", "[[1,2,3]]"),
+                                List.of(STRING().notNull(), STRING().notNull())),
+                // Shared JSON_OBJECT inside two JSON_STRING projections. JSON_STRING re-serializes
+                // the operand; without dereferencing the local ref it would wrap the already
+                // serialized JSON string a second time.
+                TestSetSpec.forFunction(
+                                BuiltInFunctionDefinitions.JSON_STRING,
+                                "Shared JSON_OBJECT inside JSON_STRING across projections")
+                        .onFieldsWithData("V")
+                        .andDataTypes(STRING())
+                        .testResult(
+                                resultSpec(
+                                        jsonString(jsonObject(JsonOnNull.NULL, "k", $("f0"))),
+                                        "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))",
+                                        "{\"k\":\"V\"}",
+                                        STRING().notNull(),
+                                        STRING().notNull()),
+                                resultSpec(
+                                        jsonString(jsonObject(JsonOnNull.NULL, "k", $("f0"))),
+                                        "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))",
+                                        "{\"k\":\"V\"}",
+                                        STRING().notNull(),
+                                        STRING().notNull()))
+                        .testSqlResult(
+                                "JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0)),"
+                                        + " JSON_STRING(JSON_OBJECT(KEY 'k' VALUE f0))",
+                                List.of("{\"k\":\"V\"}", "{\"k\":\"V\"}"),
+                                List.of(STRING().notNull(), STRING().notNull())));
+    }
+
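The JSON_STRING cases above pin a subtle failure mode: JSON_STRING serializes its operand, so if the shared JSON_OBJECT result reached it as an unexpanded local ref it would be treated as a plain string and serialized a second time. A worked example of the two outcomes (plain Java, values only; hypothetical names, not part of the test):

    final class JsonStringWrapDemo {
        static void example() {
            String inner = "{\"k\":\"V\"}";  // serialized result of the shared JSON_OBJECT
            String expected = inner;          // raw-node insertion keeps it as JSON
            // string-node insertion would quote and escape it a second time:
            String doubleWrapped = "\"{\\\"k\\\":\\\"V\\\"}\"";
        }
    }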
    // ---------------------------------------------------------------------------------------------
/**
diff --git a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java
index c0246b6fe41..c9d31af6e97 100644
--- a/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java
+++ b/flink-table/flink-table-planner/src/test/java/org/apache/flink/table/planner/runtime/stream/sql/FunctionITCase.java
@@ -69,6 +69,8 @@ import org.apache.flink.util.UserClassLoaderJarTestUtils;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
import org.junit.jupiter.params.provider.ValueSource;
import java.lang.invoke.MethodHandle;
@@ -76,7 +78,6 @@ import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.time.DayOfWeek;
import java.time.LocalDateTime;
-import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
@@ -85,7 +86,9 @@ import java.util.Optional;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
+import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
import static org.apache.flink.table.api.Expressions.$;
import static org.apache.flink.table.utils.UserDefinedFunctions.GENERATED_LOWER_UDF_CLASS;
@@ -131,10 +134,10 @@ public class FunctionITCase extends StreamingTestBase {
void testCreateCatalogFunctionInDefaultCatalog() {
String ddl1 = "create function f1 as
'org.apache.flink.function.TestFunction'";
tEnv().executeSql(ddl1);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f1");
+ assertThat(List.of(tEnv().listFunctions())).contains("f1");
tEnv().executeSql("DROP FUNCTION IF EXISTS
default_catalog.default_database.f1");
- assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f1");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f1");
}
@Test
@@ -143,10 +146,10 @@ public class FunctionITCase extends StreamingTestBase {
"create function default_catalog.default_database.f2 as"
+ " 'org.apache.flink.function.TestFunction'";
tEnv().executeSql(ddl1);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f2");
+ assertThat(List.of(tEnv().listFunctions())).contains("f2");
tEnv().executeSql("DROP FUNCTION IF EXISTS
default_catalog.default_database.f2");
- assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f2");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f2");
}
@Test
@@ -155,10 +158,10 @@ public class FunctionITCase extends StreamingTestBase {
"create function default_database.f3 as"
+ " 'org.apache.flink.function.TestFunction'";
tEnv().executeSql(ddl1);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f3");
+ assertThat(List.of(tEnv().listFunctions())).contains("f3");
tEnv().executeSql("DROP FUNCTION IF EXISTS
default_catalog.default_database.f3");
- assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f3");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f3");
}
@Test
@@ -186,7 +189,7 @@ public class FunctionITCase extends StreamingTestBase {
+ " CURRENT_DATE = CURRENT_DATE()")
.execute();
List<Row> actualRows =
CollectionUtil.iteratorToList(tableResult.collect());
- assertThat(actualRows).isEqualTo(Arrays.asList(Row.of(true, true,
true, true, true)));
+ assertThat(actualRows).isEqualTo(List.of(Row.of(true, true, true,
true, true)));
}
@Test
@@ -232,13 +235,13 @@ public class FunctionITCase extends StreamingTestBase {
String ddl4 = "drop temporary function if exists
default_catalog.default_database.f4";
tEnv().executeSql(ddl1);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f4");
+ assertThat(List.of(tEnv().listFunctions())).contains("f4");
tEnv().executeSql(ddl2);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f4");
+ assertThat(List.of(tEnv().listFunctions())).contains("f4");
tEnv().executeSql(ddl3);
- assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f4");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f4");
tEnv().executeSql(ddl1);
assertThatThrownBy(() -> tEnv().executeSql(ddl1))
@@ -276,24 +279,24 @@ public class FunctionITCase extends StreamingTestBase {
"CREATE TEMPORARY SYSTEM FUNCTION f10 AS '%s' USING
JAR '%s'",
udfClassName, jarPath);
tEnv().executeSql(ddl);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f10");
+ assertThat(List.of(tEnv().listFunctions())).contains("f10");
try (CloseableIterator<Row> itor = tEnv().executeSql("SHOW
JARS").collect()) {
assertThat(itor.hasNext()).isFalse();
}
tEnv().executeSql("DROP TEMPORARY SYSTEM FUNCTION f10");
-
assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f10");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f10");
}
@Test
void testCreateTemporarySystemFunctionWithTableAPI() {
ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath);
- tEnv().createTemporarySystemFunction("f10", udfClassName,
Arrays.asList(resourceUri));
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f10");
+ tEnv().createTemporarySystemFunction("f10", udfClassName,
List.of(resourceUri));
+ assertThat(List.of(tEnv().listFunctions())).contains("f10");
tEnv().executeSql("DROP TEMPORARY SYSTEM FUNCTION f10");
-
assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f10");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f10");
}
@Test
@@ -303,7 +306,7 @@ public class FunctionITCase extends StreamingTestBase {
testUserDefinedFunctionByUsingJar(
environment ->
environment.createTemporarySystemFunction(
- "lowerUdf", udfClassName,
Arrays.asList(resourceUri)),
+ "lowerUdf", udfClassName,
List.of(resourceUri)),
dropFunctionSql);
}
@@ -314,20 +317,20 @@ public class FunctionITCase extends StreamingTestBase {
"CREATE FUNCTION default_database.f11 AS '%s' USING
JAR '%s'",
udfClassName, jarPath);
tEnv().executeSql(ddl);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f11");
+ assertThat(List.of(tEnv().listFunctions())).contains("f11");
tEnv().executeSql("DROP FUNCTION default_database.f11");
-
assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f11");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f11");
}
@Test
void testCreateCatalogFunctionWithTableAPI() {
ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath);
- tEnv().createFunction("f11", udfClassName, Arrays.asList(resourceUri));
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f11");
+ tEnv().createFunction("f11", udfClassName, List.of(resourceUri));
+ assertThat(List.of(tEnv().listFunctions())).contains("f11");
tEnv().executeSql("DROP FUNCTION default_database.f11");
-
assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f11");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f11");
}
@Test
@@ -336,8 +339,7 @@ public class FunctionITCase extends StreamingTestBase {
String dropFunctionSql = "DROP FUNCTION default_database.lowerUdf";
testUserDefinedFunctionByUsingJar(
environment ->
- environment.createFunction(
- "lowerUdf", udfClassName,
Arrays.asList(resourceUri)),
+ environment.createFunction("lowerUdf", udfClassName,
List.of(resourceUri)),
dropFunctionSql);
}
@@ -348,20 +350,20 @@ public class FunctionITCase extends StreamingTestBase {
"CREATE TEMPORARY FUNCTION default_database.f12 AS
'%s' USING JAR '%s'",
udfClassName, jarPath);
tEnv().executeSql(ddl);
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f12");
+ assertThat(List.of(tEnv().listFunctions())).contains("f12");
tEnv().executeSql("DROP TEMPORARY FUNCTION default_database.f12");
-
assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f12");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f12");
}
@Test
void testCreateTemporaryCatalogFunctionWithTableAPI() {
ResourceUri resourceUri = new ResourceUri(ResourceType.JAR, jarPath);
- tEnv().createTemporaryFunction("f12", udfClassName,
Arrays.asList(resourceUri));
- assertThat(Arrays.asList(tEnv().listFunctions())).contains("f12");
+ tEnv().createTemporaryFunction("f12", udfClassName,
List.of(resourceUri));
+ assertThat(List.of(tEnv().listFunctions())).contains("f12");
tEnv().executeSql("DROP TEMPORARY FUNCTION default_database.f12");
-
assertThat(Arrays.asList(tEnv().listFunctions())).doesNotContain("f12");
+ assertThat(List.of(tEnv().listFunctions())).doesNotContain("f12");
}
@Test
@@ -371,7 +373,7 @@ public class FunctionITCase extends StreamingTestBase {
testUserDefinedFunctionByUsingJar(
environment ->
environment.createTemporaryFunction(
- "lowerUdf", udfClassName,
Arrays.asList(resourceUri)),
+ "lowerUdf", udfClassName,
List.of(resourceUri)),
dropFunctionSql);
}
@@ -596,7 +598,7 @@ public class FunctionITCase extends StreamingTestBase {
TableResult tableResult = tEnv().executeSql("SELECT
lowerUdf('HELLO')");
List<Row> actualRows =
CollectionUtil.iteratorToList(tableResult.collect());
- assertThat(actualRows).isEqualTo(Arrays.asList(Row.of("hello")));
+ assertThat(actualRows).isEqualTo(List.of(Row.of("hello")));
tEnv().executeSql("drop temporary function lowerUdf");
}
@@ -611,7 +613,7 @@ public class FunctionITCase extends StreamingTestBase {
private void testUserDefinedCatalogFunction(String createFunctionDDL)
throws Exception {
List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(1, "1000", 2),
Row.of(2, "1", 3),
Row.of(3, "2000", 4),
@@ -644,7 +646,7 @@ public class FunctionITCase extends StreamingTestBase {
private void testUserDefinedFunctionByUsingJar(FunctionCreator creator,
String dropFunctionDDL)
throws Exception {
List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(1, "JARK"),
Row.of(2, "RON"),
Row.of(3, "LeoNard"),
@@ -667,7 +669,7 @@ public class FunctionITCase extends StreamingTestBase {
List<Row> result = TestCollectionTableFactory.RESULT();
List<Row> expected =
- Arrays.asList(
+ List.of(
Row.of(1, "jark"),
Row.of(2, "ron"),
Row.of(3, "leonard"),
@@ -684,10 +686,10 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testPrimitiveScalarFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(Row.of(1, 1L, "-"), Row.of(2, 2L, "--"),
Row.of(3, 3L, "---"));
+ List.of(Row.of(1, 1L, "-"), Row.of(2, 2L, "--"), Row.of(3, 3L,
"---"));
final List<Row> sinkData =
- Arrays.asList(Row.of(1, 3L, "-"), Row.of(2, 6L, "--"),
Row.of(3, 9L, "---"));
+ List.of(Row.of(1, 3L, "-"), Row.of(2, 6L, "--"), Row.of(3, 9L,
"---"));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -738,7 +740,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testRowScalarFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(1, Row.of(1, "1")),
Row.of(2, Row.of(2, "2")),
Row.of(3, Row.of(3, "3")));
@@ -761,14 +763,14 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testComplexScalarFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(1, new byte[] {1, 2, 3}),
Row.of(2, new byte[] {2, 3, 4}),
Row.of(3, new byte[] {3, 4, 5}),
Row.of(null, null));
final List<Row> sinkData =
- Arrays.asList(
+ List.of(
Row.of(
1,
"1+2012-12-12 12:12:12.123456789",
@@ -834,11 +836,10 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testCustomScalarFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(Row.of(1), Row.of(2), Row.of(3),
Row.of((Integer) null));
+ List.of(Row.of(1), Row.of(2), Row.of(3), Row.of((Integer)
null));
final List<Row> sinkData =
- Arrays.asList(
- Row.of(1, 1, 5), Row.of(2, 2, 5), Row.of(3, 3, 5),
Row.of(null, null, 5));
+ List.of(Row.of(1, 1, 5), Row.of(2, 2, 5), Row.of(3, 3, 5),
Row.of(null, null, 5));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -862,7 +863,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testVarArgScalarFunction() {
- final List<Row> sourceData = Arrays.asList(Row.of("Bob", 1),
Row.of("Alice", 2));
+ final List<Row> sourceData = List.of(Row.of("Bob", 1), Row.of("Alice",
2));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -890,7 +891,7 @@ public class FunctionITCase extends StreamingTestBase {
final List<Row> actual =
CollectionUtil.iteratorToList(result.collect());
final List<Row> expected =
- Arrays.asList(
+ List.of(
Row.of(
"(INT...)",
"(INT...)",
@@ -909,7 +910,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testRawLiteralScalarFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(1, DayOfWeek.MONDAY),
Row.of(2, DayOfWeek.FRIDAY),
Row.of(null, null));
@@ -968,13 +969,390 @@ public class FunctionITCase extends StreamingTestBase {
assertThat(TestCollectionTableFactory.getResult()).containsExactlyInAnyOrder(sinkData);
}
+    @ParameterizedTest(name = "{0}")
+    @MethodSource("inputForTestCalcLocalRefReuse")
+    void testCalcLocalRefReuse(
+            String sql, List<Row> expectedRows, int expectedDetCalls, int expectedNonDetCalls) {
+        final List<Row> sourceData = List.of(Row.of("Bob"), Row.of("Alice"));
+
+        TestCollectionTableFactory.reset();
+        TestCollectionTableFactory.initData(sourceData);
+        CountingUpperScalarFunction.COUNT.set(0);
+        NonDeterministicCountingScalarFunction.COUNT.set(0);
+
+        tEnv().createTemporarySystemFunction("Det", CountingUpperScalarFunction.class);
+        tEnv().createTemporarySystemFunction("Nondet", NonDeterministicCountingScalarFunction.class);
+        tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')");
+
+        final List<Row> actual = CollectionUtil.iteratorToList(tEnv().executeSql(sql).collect());
+
+        assertThat(actual).containsExactlyElementsOf(expectedRows);
+        assertThat(CountingUpperScalarFunction.COUNT.get())
+                .as("Deterministic invocations")
+                .isEqualTo(expectedDetCalls);
+        assertThat(NonDeterministicCountingScalarFunction.COUNT.get())
+                .as("Non-deterministic invocations")
+                .isEqualTo(expectedNonDetCalls);
+    }
+
+    static Stream<Arguments> inputForTestCalcLocalRefReuse() {
+        return Stream.of(
+                Arguments.of(
+                        "SELECT Det(s), Det(s), Det(s) FROM SourceTable",
+                        List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")),
+                        2, // rows × 1 (cached once per row)
+                        0),
+                Arguments.of(
+                        "SELECT Det(s), Det(s), UPPER(s) FROM SourceTable",
+                        List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")),
+                        2, // rows × 1 (cached); built-in UPPER not counted
+                        0),
+                Arguments.of(
+                        "SELECT Det(Det(s)), Det(Det(s)), Det(Det(s)) FROM SourceTable",
+                        List.of(Row.of("BOB", "BOB", "BOB"), Row.of("ALICE", "ALICE", "ALICE")),
+                        4, // rows × 2 layers
+                        0),
+                Arguments.of(
+                        "SELECT Nondet(s), Nondet(s), Nondet(s) FROM SourceTable",
+                        List.of(
+                                Row.of("BOB_1", "BOB_2", "BOB_3"),
+                                Row.of("ALICE_4", "ALICE_5", "ALICE_6")),
+                        0,
+                        6 // rows × 3 projections
+                        ),
+                Arguments.of(
+                        "SELECT Nondet(Det(s)), Nondet(Det(s)), Nondet(Det(s)) FROM SourceTable",
+                        List.of(
+                                Row.of("BOB_1", "BOB_2", "BOB_3"),
+                                Row.of("ALICE_4", "ALICE_5", "ALICE_6")),
+                        2, // rows × 1 (inner cached)
+                        6 // rows × 3 projections
+                        ),
+                Arguments.of(
+                        "SELECT Det(Nondet(s)), Det(Nondet(s)), Det(Nondet(s)) FROM SourceTable",
+                        List.of(
+                                Row.of("BOB_1", "BOB_2", "BOB_3"),
+                                Row.of("ALICE_4", "ALICE_5", "ALICE_6")),
+                        6, // rows × 3 (non-deterministic input disables the cache)
+                        6 // rows × 3 projections
+                        ),
+                // Shared Det in the filter: cached once per row.
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND Det(s) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        2,
+                        0),
+                // Mixed UDF + built-in.
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(s) IS NOT NULL AND Det(s) <> '' AND UPPER(s) <> ''",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        2,
+                        0),
+                // Nested Det in the filter; both layers cached.
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(Det(s)) IS NOT NULL"
+                                + " AND Det(Det(s)) <> '' AND Det(Det(s)) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        4,
+                        0),
+                // Non-deterministic in the filter: never cached.
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Nondet(s) IS NOT NULL"
+                                + " AND Nondet(s) <> '' AND Nondet(s) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        0,
+                        6),
+                // Outer Nondet, inner Det cached.
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Nondet(Det(s)) IS NOT NULL"
+                                + " AND Nondet(Det(s)) <> '' AND Nondet(Det(s)) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        2,
+                        6),
+                // Det with non-deterministic input: cache bypassed.
+                Arguments.of(
+                        "SELECT s FROM SourceTable"
+                                + " WHERE Det(Nondet(s)) IS NOT NULL"
+                                + " AND Det(Nondet(s)) <> '' AND Det(Nondet(s)) <> ' '",
+                        List.of(Row.of("Bob"), Row.of("Alice")),
+                        6,
+                        6),
+                // Filter and projection share via the unified program.
+                Arguments.of(
+                        "SELECT Det(s) FROM SourceTable WHERE Det(s) = 'BOB'",
+                        List.of(Row.of("BOB")),
+                        2,
+                        0),
+                Arguments.of(
+                        "SELECT Det(s), Det(s) FROM SourceTable WHERE Det(s) = 'BOB'",
+                        List.of(Row.of("BOB", "BOB")),
+                        2,
+                        0),
+
+                // ---------------------------------------------------------------------------
+                // JSON construction scenarios. These verify that the local-ref / RexProgram CSE
+                // cache also fires when the shared sub-expression is wrapped inside (or itself
+                // is) a JSON_OBJECT / JSON_ARRAY / JSON_STRING call.
+                // ---------------------------------------------------------------------------
+
+                // JSON_OBJECT × 2 sharing inner Det: cached once per row.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s)),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"),
+                                Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")),
+                        2, // rows × 1 (cached)
+                        0),
+                // JSON_ARRAY × 2 sharing inner Det: cached.
+                Arguments.of(
+                        "SELECT JSON_ARRAY(Det(s)), JSON_ARRAY(Det(s)) FROM SourceTable",
+                        List.of(
+                                Row.of("[\"BOB\"]", "[\"BOB\"]"),
+                                Row.of("[\"ALICE\"]", "[\"ALICE\"]")),
+                        2,
+                        0),
+                // JSON_STRING × 2 sharing inner Det: cached.
+                Arguments.of(
+                        "SELECT JSON_STRING(Det(s)), JSON_STRING(Det(s)) FROM SourceTable",
+                        List.of(Row.of("\"BOB\"", "\"BOB\""), Row.of("\"ALICE\"", "\"ALICE\"")),
+                        2,
+                        0),
+                // Mixed JSON_OBJECT + JSON_ARRAY sharing the same Det.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_ARRAY(Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"k\":\"BOB\"}", "[\"BOB\"]"),
+                                Row.of("{\"k\":\"ALICE\"}", "[\"ALICE\"]")),
+                        2,
+                        0),
+                // Mixed JSON_OBJECT + JSON_STRING sharing the same Det.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s)), JSON_STRING(Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"k\":\"BOB\"}", "\"BOB\""),
+                                Row.of("{\"k\":\"ALICE\"}", "\"ALICE\"")),
+                        2,
+                        0),
+                // JSON_OBJECT × 3 sharing the same Det: cached across all 3 sites.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(s)),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(s)),"
+                                + " JSON_OBJECT(KEY 'c' VALUE Det(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}", "{\"c\":\"BOB\"}"),
+                                Row.of(
+                                        "{\"a\":\"ALICE\"}",
+                                        "{\"b\":\"ALICE\"}",
+                                        "{\"c\":\"ALICE\"}")),
+                        2,
+                        0),
+                // Nested Det(Det(s)) inside two JSON_OBJECT projections: both layers cached.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Det(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB\"}", "{\"b\":\"BOB\"}"),
+                                Row.of("{\"a\":\"ALICE\"}", "{\"b\":\"ALICE\"}")),
+                        4, // rows × 2 layers
+                        0),
+                // Nondet inside two JSON_OBJECT projections: never cached.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(s)),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Nondet(s))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"),
+                                Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")),
+                        0,
+                        4 // rows × 2 projections
+                        ),
+                // Outer Nondet, inner Det inside two JSON_OBJECT projections: Det cached.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Nondet(Det(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Nondet(Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"),
+                                Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")),
+                        2, // inner Det cached
+                        4),
+                // Outer Det, inner Nondet: outer cache disabled by the non-deterministic operand.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE Det(Nondet(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE Det(Nondet(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":\"BOB_1\"}", "{\"b\":\"BOB_2\"}"),
+                                Row.of("{\"a\":\"ALICE_3\"}", "{\"b\":\"ALICE_4\"}")),
+                        4, // outer Det not cached (non-deterministic operand)
+                        4),
+                // Filter and JSON projection share Det via the unified program.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'k' VALUE Det(s))"
+                                + " FROM SourceTable WHERE Det(s) = 'BOB'",
+                        List.of(Row.of("{\"k\":\"BOB\"}")),
+                        2,
+                        0),
+                // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two outer JSON_OBJECT
+                // projections: verifies CSE works when the cached node is itself a JSON
+                // construction call (and validates the JSON helpers' RexLocalRef deref path
+                // along the way).
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'outer1' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s))),"
+                                + " JSON_OBJECT(KEY 'outer2' VALUE JSON_OBJECT(KEY 'k' VALUE Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of(
+                                        "{\"outer1\":{\"k\":\"BOB\"}}",
+                                        "{\"outer2\":{\"k\":\"BOB\"}}"),
+                                Row.of(
+                                        "{\"outer1\":{\"k\":\"ALICE\"}}",
+                                        "{\"outer2\":{\"k\":\"ALICE\"}}")),
+                        2,
+                        0),
+                // Shared inner JSON_ARRAY(Det(s)) inside two outer JSON_OBJECT projections.
+                Arguments.of(
+                        "SELECT JSON_OBJECT(KEY 'a' VALUE JSON_ARRAY(Det(s))),"
+                                + " JSON_OBJECT(KEY 'b' VALUE JSON_ARRAY(Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("{\"a\":[\"BOB\"]}", "{\"b\":[\"BOB\"]}"),
+                                Row.of("{\"a\":[\"ALICE\"]}", "{\"b\":[\"ALICE\"]}")),
+                        2,
+                        0),
+                // Shared inner JSON_OBJECT(KEY 'k' VALUE Det(s)) inside two JSON_ARRAY
+                // projections.
+                Arguments.of(
+                        "SELECT JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s))),"
+                                + " JSON_ARRAY(JSON_OBJECT(KEY 'k' VALUE Det(s)))"
+                                + " FROM SourceTable",
+                        List.of(
+                                Row.of("[{\"k\":\"BOB\"}]", "[{\"k\":\"BOB\"}]"),
+                                Row.of("[{\"k\":\"ALICE\"}]", "[{\"k\":\"ALICE\"}]")),
+                        2,
+                        0));
+    }
+
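All of the expectations above follow a single counting model; a compact restatement as a sketch (the names are ours, not the test's):

    final class CountingModel {
        // Counting model behind the expected call counts (a sketch, not test code):
        //   deterministic call with deterministic operands -> one exprList slot,
        //     evaluated once per row regardless of call sites;
        //   non-deterministic call, or any non-deterministic operand -> evaluated
        //     once per row per call site.
        static void example() {
            int rows = 2;                 // Bob, Alice
            int detCached = rows * 1;     // "Det(s), Det(s), Det(s)"    -> 2
            int detNested = rows * 2;     // "Det(Det(s))" at 3 sites    -> 4 (2 layers)
            int nondet = rows * 3;        // "Nondet(s)" at 3 sites      -> 6
            int detOverNondet = rows * 3; // "Det(Nondet(s))" at 3 sites -> 6
        }
    }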
+    @Test
+    void testLocalRefReuseForMixedArgs() {
+        final List<Row> sourceData = List.of(Row.of("Bob"), Row.of("Alice"));
+        final int callSites = 2;
+
+        TestCollectionTableFactory.reset();
+        TestCollectionTableFactory.initData(sourceData);
+        CountingUpperScalarFunction.COUNT.set(0);
+        NonDeterministicCountingScalarFunction.COUNT.set(0);
+        CountingConcat3ScalarFunction.COUNT.set(0);
+
+        tEnv().createTemporarySystemFunction("Det", CountingUpperScalarFunction.class);
+        tEnv().createTemporarySystemFunction("Nondet", NonDeterministicCountingScalarFunction.class);
+        tEnv().createTemporarySystemFunction("Concat3", CountingConcat3ScalarFunction.class);
+        tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')");
+
+        final List<Row> actual =
+                CollectionUtil.iteratorToList(
+                        tEnv().executeSql(
+                                        "SELECT Concat3(Det(s), Nondet(s), Det(s)),"
+                                                + " Concat3(Det(s), Nondet(s), Det(s))"
+                                                + " FROM SourceTable")
+                                .collect());
+
+        assertThat(actual)
+                .containsExactly(
+                        Row.of("BOB/BOB_1/BOB", "BOB/BOB_2/BOB"),
+                        Row.of("ALICE/ALICE_3/ALICE", "ALICE/ALICE_4/ALICE"));
+
+        assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size());
+        assertThat(NonDeterministicCountingScalarFunction.COUNT.get())
+                .isEqualTo(sourceData.size() * callSites);
+        // Concat3 itself is deterministic, but its non-deterministic Nondet(s) operand
+        // makes each call site distinct, so its result cannot be cached.
+        assertThat(CountingConcat3ScalarFunction.COUNT.get())
+                .isEqualTo(sourceData.size() * callSites);
+    }
+
+    @Test
+    void testCalcSharesSubExpressionBetweenFilterAndProjection() {
+        final List<Row> sourceData =
+                List.of(Row.of("Bob"), Row.of("Bob"), Row.of("Alice"), Row.of("Alice"));
+
+        TestCollectionTableFactory.reset();
+        TestCollectionTableFactory.initData(sourceData);
+        CountingUpperScalarFunction.COUNT.set(0);
+
+        tEnv().createTemporarySystemFunction("CountingUpper", CountingUpperScalarFunction.class);
+        tEnv().executeSql("CREATE TABLE SourceTable (s STRING) WITH ('connector' = 'COLLECTION')");
+
+        final List<Row> actual =
+                CollectionUtil.iteratorToList(
+                        tEnv().executeSql(
+                                        "SELECT CountingUpper(s) FROM SourceTable"
+                                                + " WHERE CountingUpper(s) = 'BOB' AND CountingUpper(s) <> 'BOB2'")
+                                .collect());
+
+        assertThat(actual).containsExactly(Row.of("BOB"), Row.of("BOB"));
+
+        // Filter and projection share via the unified RexProgram, so the UDF runs once per
+        // source row regardless of how many call sites name it.
+        assertThat(CountingUpperScalarFunction.COUNT.get()).isEqualTo(sourceData.size());
+    }
+
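For orientation, the unified Calc program for the query above has roughly the following shape (conceptual sketch; slot numbers are illustrative). The single exprList entry for the UDF call is what both the condition and the projection reference:

    // Conceptual RexProgram for
    //   SELECT CountingUpper(s) FROM SourceTable
    //   WHERE CountingUpper(s) = 'BOB' AND CountingUpper(s) <> 'BOB2'
    //
    //   exprList:  $t0 = s
    //              $t1 = CountingUpper($t0)   <- one slot, three call sites
    //              $t2 = =($t1, 'BOB')
    //              $t3 = <>($t1, 'BOB2')
    //              $t4 = AND($t2, $t3)
    //   projects:  [$t1]
    //   condition: $t4
    //
    // Reusing the generated code for $t1 yields exactly one invocation per row.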
+    /**
+     * Pins the CASE-WHEN guard interaction with the RexLocalRef cache.
+     *
+     * <p>Prior to scoped caching, RexProgramBuilder collapsed the division {@code a / b} into a
+     * single exprList entry; the codegen visitor cached the body and {@code
+     * CalcCodeGenerator.reuseLocalRefCode()} hoisted that body to the top of the generated method,
+     * evaluating {@code a / b} for every row regardless of the surrounding {@code CASE WHEN b > 0}.
+     * Rows with {@code b = 0} then threw {@code java.lang.ArithmeticException: Division undefined}
+     * (caught in the wild on TPC-DS query 34). With scoped caching the division body lives inside
+     * the THEN-branch's generated code and never executes when the guard is false.
+     */
+    @Test
+    void testCalcCaseGuardShortCircuit() {
+        final List<Row> sourceData =
+                List.of(Row.of(10, 0), Row.of(10, 2), Row.of(20, 0), Row.of(30, 5), Row.of(40, 0));
+
+        TestCollectionTableFactory.reset();
+        TestCollectionTableFactory.initData(sourceData);
+        tEnv().executeSql(
+                "CREATE TABLE SourceTable (a INT, b INT) WITH ('connector' = 'COLLECTION')");
+
+        final List<Row> actual =
+                CollectionUtil.iteratorToList(
+                        tEnv().executeSql(
+                                        "SELECT a FROM SourceTable WHERE"
+                                                + " (CASE WHEN b > 0"
+                                                + " THEN CAST(a AS DECIMAL(7,2))"
+                                                + " / CAST(b AS DECIMAL(7,2))"
+                                                + " ELSE NULL END) > 1.2")
+                                .collect());
+
+        // Row(10,2) → 10/2 = 5.0 (>1.2)
+        // Row(30,5) → 30/5 = 6.0 (>1.2)
+        // Rows with b=0 must NOT enter the THEN-branch (the division would fail).
+        assertThat(actual).containsExactly(Row.of(10), Row.of(30));
+    }
+
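The hazard described in the javadoc is easy to reproduce outside Flink; a plain-Java sketch of hoisted versus guard-scoped evaluation (our names, not generated code):

    import java.math.BigDecimal;
    import java.math.MathContext;

    public final class GuardScopeSketch {
        public static void main(String[] args) {
            BigDecimal a = new BigDecimal("10.00");
            BigDecimal b = BigDecimal.ZERO;
            // Hoisted form (pre-fix): the shared division would run before the guard
            // and throw ArithmeticException for b == 0.
            //   BigDecimal hoisted = a.divide(b, MathContext.DECIMAL32);
            // Guard-scoped form (post-fix): the division only runs when the guard holds.
            BigDecimal scoped =
                    b.signum() > 0 ? a.divide(b, MathContext.DECIMAL32) : null;
            System.out.println(scoped); // null for b = 0; no exception thrown
        }
    }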
@Test
void testStructuredScalarFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12),
Row.of(null, 0));
+ List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null,
0));
final List<Row> sinkData =
- Arrays.asList(
+ List.of(
Row.of("Bob 42", "Tyler"),
Row.of("Alice 12", "Tyler"),
Row.of("<<null>>", "Tyler"));
@@ -1020,11 +1398,10 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testRowTableFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(
- Row.of("1,2,3"), Row.of("2,3,4"), Row.of("3,4,5"),
Row.of((String) null));
+ List.of(Row.of("1,2,3"), Row.of("2,3,4"), Row.of("3,4,5"),
Row.of((String) null));
final List<Row> sinkData =
- Arrays.asList(
+ List.of(
Row.of("1,2,3", new String[] {"1", "2", "3"}),
Row.of("2,3,4", new String[] {"2", "3", "4"}),
Row.of("3,4,5", new String[] {"3", "4", "5"}));
@@ -1048,10 +1425,9 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testStructuredTableFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12),
Row.of(null, 0));
+ List.of(Row.of("Bob", 42), Row.of("Alice", 12), Row.of(null,
0));
- final List<Row> sinkData =
- Arrays.asList(Row.of("Bob", 42), Row.of("Alice", 12),
Row.of(null, 0));
+ final List<Row> sinkData = List.of(Row.of("Bob", 42), Row.of("Alice",
12), Row.of(null, 0));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1157,10 +1533,10 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testNamedArgumentsScalarFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"),
Row.of(5, 6, "str3"));
+ List.of(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5,
6, "str3"));
final List<Row> sinkData =
- Arrays.asList(Row.of(1, 2, "1: 2"), Row.of(3, 4, "3: 4"),
Row.of(5, 6, "5: 6"));
+ List.of(Row.of(1, 2, "1: 2"), Row.of(3, 4, "3: 4"), Row.of(5,
6, "5: 6"));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1182,7 +1558,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testNamedParametersScalarFunctionWithOverloadedMethod() throws
Exception {
final List<Row> sourceData =
- Arrays.asList(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"),
Row.of(5, 6, "str3"));
+ List.of(Row.of(1, 2, "str1"), Row.of(3, 4, "str2"), Row.of(5,
6, "str3"));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1206,8 +1582,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testNamedArgumentsScalarFunctionWithOptionalArguments() throws
Exception {
- final List<Row> sinkData =
- Arrays.asList(Row.of("s1: null", "null: s2", "s1: s2", "null:
null"));
+ final List<Row> sinkData = List.of(Row.of("s1: null", "null: s2", "s1:
s2", "null: null"));
TestCollectionTableFactory.reset();
tEnv().executeSql(
@@ -1230,14 +1605,13 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testNamedArgumentAggregateFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"),
"a", "b", 1, 2),
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"),
"c", "d", 33, 44),
Row.of(LocalDateTime.parse("2007-12-03T10:15:32"),
"e", "f", 5, 6),
Row.of(LocalDateTime.parse("2007-12-03T10:15:32"),
"gg", "hh", 7, 88));
- final List<Row> sinkData =
- Arrays.asList(Row.of("a: b", "b: a"), Row.of("gg: hh", "hh:
gg"));
+ final List<Row> sinkData = List.of(Row.of("a: b", "b: a"), Row.of("gg:
hh", "hh: gg"));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1265,14 +1639,14 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testNamedArgumentAggregateFunctionWithOptionalArguments() throws
Exception {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"),
"a", "b", 1, 2),
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"),
"c", "d", 33, 44),
Row.of(LocalDateTime.parse("2007-12-03T10:15:32"),
"e", "f", 5, 6),
Row.of(LocalDateTime.parse("2007-12-03T10:15:32"),
"gg", "hh", 7, 88));
final List<Row> sinkData =
- Arrays.asList(Row.of("a: null", "null: b"), Row.of("gg: null",
"null: hh"));
+ List.of(Row.of("a: null", "null: b"), Row.of("gg: null",
"null: hh"));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1346,7 +1720,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testAggregateFunction() throws Exception {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"),
"Bob"),
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"),
"Alice"),
Row.of(LocalDateTime.parse("2007-12-03T10:15:30"),
null),
@@ -1355,7 +1729,7 @@ public class FunctionITCase extends StreamingTestBase {
Row.of(LocalDateTime.parse("2007-12-03T10:15:32"),
"Alice"));
final List<Row> sinkData =
- Arrays.asList(
+ List.of(
Row.of(
"Jonathan",
"Alice=(Alice, 5), Bob=(Bob, 3),
Jonathan=(Jonathan, 8)"),
@@ -1409,10 +1783,10 @@ public class FunctionITCase extends StreamingTestBase {
private void testLookupTableFunctionBase(String
lookupTableFunctionClassName)
throws ExecutionException, InterruptedException {
- final List<Row> sourceData = Arrays.asList(Row.of("Bob"),
Row.of("Alice"));
+ final List<Row> sourceData = List.of(Row.of("Bob"), Row.of("Alice"));
final List<Row> sinkData =
- Arrays.asList(
+ List.of(
Row.of("Bob", new byte[0]),
Row.of("Bob", new byte[] {66, 111, 98}),
Row.of("Alice", new byte[0]),
@@ -1459,7 +1833,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testSpecializedFunction() {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of("Bob", 1, new BigDecimal("123.45")),
Row.of("Alice", 2, new BigDecimal("123.456")));
@@ -1489,7 +1863,7 @@ public class FunctionITCase extends StreamingTestBase {
final List<Row> actual =
CollectionUtil.iteratorToList(result.collect());
final List<Row> expected =
- Arrays.asList(
+ List.of(
Row.of("CHAR(7) NOT NULL", "STRING", "INT",
"DECIMAL(6, 3)"),
Row.of("CHAR(7) NOT NULL", "STRING", "INT",
"DECIMAL(6, 3)"));
assertThat(actual).isEqualTo(expected);
@@ -1498,7 +1872,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testSpecializedFunctionWithExpressionEvaluation() {
final List<Row> sourceData =
- Arrays.asList(
+ List.of(
Row.of("Bob", new Integer[] {1, 2, 3}, new
BigDecimal("123.000")),
Row.of("Bob", new Integer[] {4, 5, 6}, new
BigDecimal("123.456")),
Row.of("Alice", new Integer[] {1, 2, 3}, null),
@@ -1530,7 +1904,7 @@ public class FunctionITCase extends StreamingTestBase {
final List<Row> actual =
CollectionUtil.iteratorToList(result.collect());
final List<Row> expected =
- Arrays.asList(
+ List.of(
Row.of("Bob", null, null),
Row.of(
"Bob",
@@ -1543,7 +1917,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testTimestampNotNull() {
- List<Row> sourceData = Arrays.asList(Row.of(1), Row.of(2));
+ List<Row> sourceData = List.of(Row.of(1), Row.of(2));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1557,7 +1931,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testIsNullType() {
- List<Row> sourceData = Arrays.asList(Row.of(1), Row.of((Object) null));
+ List<Row> sourceData = List.of(Row.of(1), Row.of((Object) null));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1571,7 +1945,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testWithBoolNotNullTypeHint() {
- List<Row> sourceData = Arrays.asList(Row.of(1, 2), Row.of(2, 3));
+ List<Row> sourceData = List.of(Row.of(1, 2), Row.of(2, 3));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1605,7 +1979,7 @@ public class FunctionITCase extends StreamingTestBase {
@Test
void testUdfWithMultiLocalVariables() {
- List<Row> sourceData = Arrays.asList(Row.of(1L, 2L), Row.of(2L, 3L));
+ List<Row> sourceData = List.of(Row.of(1L, 2L), Row.of(2L, 3L));
TestCollectionTableFactory.reset();
TestCollectionTableFactory.initData(sourceData);
@@ -1620,7 +1994,7 @@ public class FunctionITCase extends StreamingTestBase {
CollectionUtil.iteratorToList(
tEnv().executeSql("SELECT MultiLocalVariables(x, y)
FROM SourceTable")
.collect());
- assertThat(actualRows).isEqualTo(Arrays.asList(Row.of(2L),
Row.of(6L)));
+ assertThat(actualRows).isEqualTo(List.of(Row.of(2L), Row.of(6L)));
}
    // --------------------------------------------------------------------------------------------
@@ -1757,6 +2131,41 @@ public class FunctionITCase extends StreamingTestBase {
}
}
+ /** Deterministic function with a counter. */
+ public static class CountingUpperScalarFunction extends ScalarFunction {
+ public static final AtomicInteger COUNT = new AtomicInteger();
+
+ public String eval(String s) {
+ COUNT.incrementAndGet();
+ return s == null ? null : s.toUpperCase();
+ }
+ }
+
+ /** Deterministic function with a counter and 3 args. */
+ public static class CountingConcat3ScalarFunction extends ScalarFunction {
+ public static final AtomicInteger COUNT = new AtomicInteger();
+
+ public String eval(String a, String b, String c) {
+ COUNT.incrementAndGet();
+ return a + "/" + b + "/" + c;
+ }
+ }
+
+ /** Non-deterministic function with a counter. */
+    public static class NonDeterministicCountingScalarFunction extends ScalarFunction {
+ public static final AtomicInteger COUNT = new AtomicInteger();
+
+ public String eval(String s) {
+ final int count = COUNT.incrementAndGet();
+ return s == null ? null : s.toUpperCase() + "_" + count;
+ }
+
+ @Override
+ public boolean isDeterministic() {
+ return false;
+ }
+ }
+
    /** Function that has a custom type inference that is broader than the actual implementation. */
public static class CustomScalarFunction extends ScalarFunction {
public Integer eval(Integer... args) {