This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/flink.git
commit 43b8a36efc9275c3b95dd2d6961ef11e8e5d07e9 Author: JingsongLi <lzljs3620...@aliyun.com> AuthorDate: Thu Aug 22 12:57:06 2019 +0200 [FLINK-13774][table-planner-blink] Expressions of DeclarativeAggregateFunction should be resolved --- .../expressions/CallExpressionResolver.java | 57 ++++++++++ .../expressions/DeclarativeExpressionResolver.java | 95 ++++++++++++++++ .../planner/expressions/ExpressionBuilder.java | 53 +++++---- .../aggfunctions/SingleValueAggFunction.java | 12 ++- .../codegen/agg/DeclarativeAggCodeGen.scala | 120 +++++++-------------- .../codegen/agg/batch/AggCodeGenHelper.scala | 78 ++++---------- .../codegen/agg/batch/HashAggCodeGenHelper.scala | 71 ++++-------- .../codegen/agg/batch/WindowCodeGenerator.scala | 4 +- .../expressions/PlannerExpressionConverter.scala | 8 ++ .../planner/expressions/fieldExpression.scala | 55 ++++++++++ .../table/planner/plan/utils/AggregateUtil.scala | 8 +- 11 files changed, 339 insertions(+), 222 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/CallExpressionResolver.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/CallExpressionResolver.java new file mode 100644 index 0000000..a46c786 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/CallExpressionResolver.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.expressions; + +import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; +import org.apache.flink.table.expressions.resolver.ExpressionResolver; +import org.apache.flink.table.planner.calcite.FlinkContext; +import org.apache.flink.util.Preconditions; + +import org.apache.calcite.tools.RelBuilder; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +/** + * Planner expression resolver for {@link UnresolvedCallExpression}. + */ +public class CallExpressionResolver { + + private final ExpressionResolver resolver; + + public CallExpressionResolver(RelBuilder relBuilder) { + // dummy way to get context + FlinkContext context = (FlinkContext) relBuilder + .values(new String[]{"dummyField"}, "dummyValue") + .build() + .getCluster().getPlanner().getContext(); + this.resolver = ExpressionResolver.resolverFor( + name -> Optional.empty(), + context.getFunctionCatalog()).build(); + } + + public ResolvedExpression resolve(Expression expression) { + List<ResolvedExpression> resolved = resolver.resolve(Collections.singletonList(expression)); + Preconditions.checkArgument(resolved.size() == 1); + return resolved.get(0); + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/DeclarativeExpressionResolver.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/DeclarativeExpressionResolver.java new file mode 100644 index 0000000..6b1b429 --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/DeclarativeExpressionResolver.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.planner.expressions; + +import org.apache.flink.table.expressions.Expression; +import org.apache.flink.table.expressions.ExpressionDefaultVisitor; +import org.apache.flink.table.expressions.ResolvedExpression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; +import org.apache.flink.table.expressions.UnresolvedReferenceExpression; +import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction; + +import org.apache.calcite.tools.RelBuilder; +import org.apache.commons.lang3.ArrayUtils; + +import java.util.stream.Collectors; + +/** + * Abstract class to resolve the expressions in {@link DeclarativeAggregateFunction}. + */ +public abstract class DeclarativeExpressionResolver extends ExpressionDefaultVisitor<ResolvedExpression> { + + private final DeclarativeAggregateFunction function; + private final boolean isMerge; + private final CallExpressionResolver resolver; + + public DeclarativeExpressionResolver( + RelBuilder relBuilder, DeclarativeAggregateFunction function, boolean isMerge) { + this.function = function; + this.isMerge = isMerge; + this.resolver = new CallExpressionResolver(relBuilder); + } + + @Override + protected ResolvedExpression defaultMethod(Expression expression) { + if (expression instanceof UnresolvedReferenceExpression) { + UnresolvedReferenceExpression expr = (UnresolvedReferenceExpression) expression; + String name = expr.getName(); + int localIndex = ArrayUtils.indexOf(function.aggBufferAttributes(), expr); + if (localIndex == -1) { + // We always use UnresolvedFieldReference to represent reference of input field. + // In non-merge case, the input is operand of the aggregate function. But in merge + // case, the input is aggregate buffers which sent by local aggregate. + if (isMerge) { + return toMergeInputExpr(name, ArrayUtils.indexOf(function.mergeOperands(), expr)); + } else { + return toAccInputExpr(name, ArrayUtils.indexOf(function.operands(), expr)); + } + } else { + return toAggBufferExpr(name, localIndex); + } + } else if (expression instanceof UnresolvedCallExpression) { + UnresolvedCallExpression unresolvedCall = (UnresolvedCallExpression) expression; + return resolver.resolve(new UnresolvedCallExpression( + unresolvedCall.getFunctionDefinition(), + unresolvedCall.getChildren().stream() + .map(c -> c.accept(DeclarativeExpressionResolver.this)) + .collect(Collectors.toList()))); + } else if (expression instanceof ResolvedExpression) { + return (ResolvedExpression) expression; + } else { + return resolver.resolve(expression); + } + } + + /** + * When merge phase, for inputs. + */ + public abstract ResolvedExpression toMergeInputExpr(String name, int localIndex); + + /** + * When accumulate phase, for inputs. + */ + public abstract ResolvedExpression toAccInputExpr(String name, int localIndex); + + /** + * For aggregate buffer. + */ + public abstract ResolvedExpression toAggBufferExpr(String name, int localIndex); +} diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java index d96e698..baddfd8 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/expressions/ExpressionBuilder.java @@ -20,6 +20,8 @@ package org.apache.flink.table.planner.expressions; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.TypeLiteralExpression; +import org.apache.flink.table.expressions.UnresolvedCallExpression; +import org.apache.flink.table.expressions.ValueLiteralExpression; import org.apache.flink.table.expressions.utils.ApiExpressionUtils; import org.apache.flink.table.functions.FunctionDefinition; import org.apache.flink.table.types.DataType; @@ -42,92 +44,91 @@ import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.OR; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.PLUS; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.REINTERPRET_CAST; import static org.apache.flink.table.functions.BuiltInFunctionDefinitions.TIMES; -import static org.apache.flink.table.planner.functions.InternalFunctionDefinitions.THROW_EXCEPTION; /** * Builder for {@link Expression}s. */ public class ExpressionBuilder { - public static Expression nullOf(DataType type) { + public static ValueLiteralExpression nullOf(DataType type) { return literal(null, type); } - public static Expression literal(Object value) { + public static ValueLiteralExpression literal(Object value) { return ApiExpressionUtils.valueLiteral(value); } - public static Expression literal(Object value, DataType type) { + public static ValueLiteralExpression literal(Object value, DataType type) { return ApiExpressionUtils.valueLiteral(value, type); } - public static Expression call(FunctionDefinition functionDefinition, Expression... args) { + public static UnresolvedCallExpression call(FunctionDefinition functionDefinition, Expression... args) { return ApiExpressionUtils.unresolvedCall(functionDefinition, args); } - public static Expression call(FunctionDefinition functionDefinition, List<Expression> args) { + public static UnresolvedCallExpression call(FunctionDefinition functionDefinition, List<Expression> args) { return ApiExpressionUtils.unresolvedCall(functionDefinition, args.toArray(new Expression[0])); } - public static Expression and(Expression arg1, Expression arg2) { + public static UnresolvedCallExpression and(Expression arg1, Expression arg2) { return call(AND, arg1, arg2); } - public static Expression or(Expression arg1, Expression arg2) { + public static UnresolvedCallExpression or(Expression arg1, Expression arg2) { return call(OR, arg1, arg2); } - public static Expression not(Expression arg) { + public static UnresolvedCallExpression not(Expression arg) { return call(NOT, arg); } - public static Expression isNull(Expression input) { + public static UnresolvedCallExpression isNull(Expression input) { return call(IS_NULL, input); } - public static Expression ifThenElse(Expression condition, Expression ifTrue, - Expression ifFalse) { + public static UnresolvedCallExpression ifThenElse(Expression condition, Expression ifTrue, + Expression ifFalse) { return call(IF, condition, ifTrue, ifFalse); } - public static Expression plus(Expression input1, Expression input2) { + public static UnresolvedCallExpression plus(Expression input1, Expression input2) { return call(PLUS, input1, input2); } - public static Expression minus(Expression input1, Expression input2) { + public static UnresolvedCallExpression minus(Expression input1, Expression input2) { return call(MINUS, input1, input2); } - public static Expression div(Expression input1, Expression input2) { + public static UnresolvedCallExpression div(Expression input1, Expression input2) { return call(DIVIDE, input1, input2); } - public static Expression times(Expression input1, Expression input2) { + public static UnresolvedCallExpression times(Expression input1, Expression input2) { return call(TIMES, input1, input2); } - public static Expression mod(Expression input1, Expression input2) { + public static UnresolvedCallExpression mod(Expression input1, Expression input2) { return call(MOD, input1, input2); } - public static Expression equalTo(Expression input1, Expression input2) { + public static UnresolvedCallExpression equalTo(Expression input1, Expression input2) { return call(EQUALS, input1, input2); } - public static Expression lessThan(Expression input1, Expression input2) { + public static UnresolvedCallExpression lessThan(Expression input1, Expression input2) { return call(LESS_THAN, input1, input2); } - public static Expression greaterThan(Expression input1, Expression input2) { + public static UnresolvedCallExpression greaterThan(Expression input1, Expression input2) { return call(GREATER_THAN, input1, input2); } - public static Expression cast(Expression child, Expression type) { + public static UnresolvedCallExpression cast(Expression child, Expression type) { return call(CAST, child, type); } - public static Expression reinterpretCast(Expression child, Expression type, - boolean checkOverflow) { + public static UnresolvedCallExpression reinterpretCast(Expression child, Expression type, + boolean checkOverflow) { return call(REINTERPRET_CAST, child, type, literal(checkOverflow)); } @@ -135,11 +136,7 @@ public class ExpressionBuilder { return ApiExpressionUtils.typeLiteral(type); } - public static Expression concat(Expression input1, Expression input2) { + public static UnresolvedCallExpression concat(Expression input1, Expression input2) { return call(CONCAT, input1, input2); } - - public static Expression throwException(String msg, DataType type) { - return call(THROW_EXCEPTION, literal(msg), typeLiteral(type)); - } } diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SingleValueAggFunction.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SingleValueAggFunction.java index 865c0c2..24b2d44 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SingleValueAggFunction.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/planner/functions/aggfunctions/SingleValueAggFunction.java @@ -19,12 +19,15 @@ package org.apache.flink.table.planner.functions.aggfunctions; import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.expressions.CallExpression; import org.apache.flink.table.expressions.Expression; import org.apache.flink.table.expressions.UnresolvedReferenceExpression; import org.apache.flink.table.types.DataType; import org.apache.flink.table.types.logical.DecimalType; import org.apache.flink.table.types.logical.TimeType; +import java.util.Arrays; + import static org.apache.flink.table.expressions.utils.ApiExpressionUtils.unresolvedRef; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.equalTo; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.greaterThan; @@ -34,7 +37,8 @@ import static org.apache.flink.table.planner.expressions.ExpressionBuilder.minus import static org.apache.flink.table.planner.expressions.ExpressionBuilder.nullOf; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.or; import static org.apache.flink.table.planner.expressions.ExpressionBuilder.plus; -import static org.apache.flink.table.planner.expressions.ExpressionBuilder.throwException; +import static org.apache.flink.table.planner.expressions.ExpressionBuilder.typeLiteral; +import static org.apache.flink.table.planner.functions.InternalFunctionDefinitions.THROW_EXCEPTION; /** * Base class for built-in single value aggregate function. @@ -118,6 +122,12 @@ public abstract class SingleValueAggFunction extends DeclarativeAggregateFunctio return value; } + private static Expression throwException(String msg, DataType type) { + // it is the internal function without catalog. + // so it can not be find in any catalog or built-in functions. + return new CallExpression(THROW_EXCEPTION, Arrays.asList(literal(msg), typeLiteral(type)), type); + } + /** * Built-in byte single value aggregate function. */ diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala index 30efc86..a175f64 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/DeclarativeAggCodeGen.scala @@ -18,19 +18,16 @@ package org.apache.flink.table.planner.codegen.agg import org.apache.flink.table.expressions._ -import org.apache.flink.table.expressions.utils.ApiExpressionUtils import org.apache.flink.table.planner.codegen.CodeGenUtils.primitiveTypeTermForType import org.apache.flink.table.planner.codegen.agg.AggsHandlerCodeGenerator.DISTINCT_KEY_TERM import org.apache.flink.table.planner.codegen.{CodeGeneratorContext, ExprCodeGenerator, GeneratedExpression} -import org.apache.flink.table.planner.expressions.{ResolvedAggInputReference, ResolvedAggLocalReference, ResolvedDistinctKeyReference, RexNodeConverter} +import org.apache.flink.table.planner.expressions.{DeclarativeExpressionResolver, ResolvedAggInputReference, ResolvedAggLocalReference, ResolvedDistinctKeyReference, RexNodeConverter} import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.plan.utils.AggregateInfo import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType import org.apache.flink.table.types.logical.LogicalType -import org.apache.calcite.tools.RelBuilder -import org.apache.flink.table.functions.BuiltInFunctionDefinitions -import scala.collection.JavaConverters._ +import org.apache.calcite.tools.RelBuilder /** * It is for code generate aggregation functions that are specified using expressions. @@ -216,87 +213,52 @@ class DeclarativeAggCodeGen( */ private case class ResolveReference( isMerge: Boolean = false, - isDistinctMerge: Boolean = false) extends ExpressionVisitor[Expression] { - - override def visit(call: CallExpression): Expression = ??? - - override def visit(valueLiteralExpression: ValueLiteralExpression): Expression = { - valueLiteralExpression - } - - override def visit(input: FieldReferenceExpression): Expression = { - input + isDistinctMerge: Boolean = false) + extends DeclarativeExpressionResolver(relBuilder, function, isMerge) { + + override def toMergeInputExpr(name: String, localIndex: Int): ResolvedExpression = { + // in merge case, the input1 is mergedAcc + new ResolvedAggInputReference( + name, + mergedAccOffset + bufferIndexes(localIndex), + bufferTypes(localIndex)) } - override def visit(typeLiteral: TypeLiteralExpression): Expression = { - typeLiteral - } - - private def visitUnresolvedCallExpression( - unresolvedCall: UnresolvedCallExpression): Expression = { - ApiExpressionUtils.unresolvedCall( - unresolvedCall.getFunctionDefinition, - unresolvedCall.getChildren.asScala.map(_.accept(this)): _*) - } - - private def visitUnresolvedReference(input: UnresolvedReferenceExpression) - : Expression = { - function.aggBufferAttributes.indexOf(input) match { - case -1 => - // Not find in agg buffers, it is a operand, represent reference of input field. - // In non-merge case, the input is the operand of the aggregate function. - // In merge case, the input is the aggregate buffers sent by local aggregate. - if (isMerge) { - val localIndex = function.mergeOperands.indexOf(input) - // in merge case, the input1 is mergedAcc - new ResolvedAggInputReference( - input.getName, - mergedAccOffset + bufferIndexes(localIndex), - bufferTypes(localIndex)) + override def toAccInputExpr(name: String, localIndex: Int): ResolvedExpression = { + val inputIndex = argIndexes(localIndex) + if (inputIndex >= inputTypes.length) { // it is a constant + val constantIndex = inputIndex - inputTypes.length + val constantTerm = constantExprs(constantIndex).resultTerm + val nullTerm = constantExprs(constantIndex).nullTerm + val constantType = constantExprs(constantIndex).resultType + // constant is reused as member variable + new ResolvedAggLocalReference( + constantTerm, + nullTerm, + constantType) + } else { // it is a input field + if (isDistinctMerge) { // this is called from distinct merge + if (function.operandCount == 1) { + // the distinct key is a BoxedValue + new ResolvedDistinctKeyReference(name, argTypes(localIndex)) } else { - val localIndex = function.operands.indexOf(input) - val inputIndex = argIndexes(localIndex) - if (inputIndex >= inputTypes.length) { // it is a constant - val constantIndex = inputIndex - inputTypes.length - val constantTerm = constantExprs(constantIndex).resultTerm - val nullTerm = constantExprs(constantIndex).nullTerm - val constantType = constantExprs(constantIndex).resultType - // constant is reused as member variable - new ResolvedAggLocalReference( - constantTerm, - nullTerm, - constantType) - } else { // it is a input field - if (isDistinctMerge) { // this is called from distinct merge - if (function.operandCount == 1) { - // the distinct key is a BoxedValue - new ResolvedDistinctKeyReference(input.getName, argTypes(localIndex)) - } else { - // the distinct key is a BaseRow - new ResolvedAggInputReference(input.getName, localIndex, argTypes(localIndex)) - } - } else { - // the input is the inputRow - new ResolvedAggInputReference( - input.getName, argIndexes(localIndex), argTypes(localIndex)) - } - } + // the distinct key is a BaseRow + new ResolvedAggInputReference(name, localIndex, argTypes(localIndex)) } - case localIndex => - // it is a agg buffer. - val name = bufferTerms(localIndex) - val nullTerm = bufferNullTerms(localIndex) - // buffer access is reused as member variable - new ResolvedAggLocalReference(name, nullTerm, bufferTypes(localIndex)) + } else { + // the input is the inputRow + new ResolvedAggInputReference( + name, argIndexes(localIndex), argTypes(localIndex)) + } } } - override def visit(other: Expression): Expression = { - other match { - case u : UnresolvedReferenceExpression => visitUnresolvedReference(u) - case u : UnresolvedCallExpression => visitUnresolvedCallExpression(u) - case _ => other - } + override def toAggBufferExpr(name: String, localIndex: Int): ResolvedExpression = { + // it is a agg buffer. + val name = bufferTerms(localIndex) + val nullTerm = bufferNullTerms(localIndex) + // buffer access is reused as member variable + new ResolvedAggLocalReference(name, nullTerm, bufferTypes(localIndex)) } } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala index 2b95509..d235f91 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/AggCodeGenHelper.scala @@ -22,13 +22,12 @@ import org.apache.flink.api.common.ExecutionConfig import org.apache.flink.runtime.util.SingleElementIterator import org.apache.flink.streaming.api.operators.OneInputStreamOperator import org.apache.flink.table.dataformat.{BaseRow, GenericRow} -import org.apache.flink.table.expressions.utils.ApiExpressionUtils -import org.apache.flink.table.expressions.{Expression, ExpressionVisitor, FieldReferenceExpression, TypeLiteralExpression, UnresolvedCallExpression, UnresolvedReferenceExpression, ValueLiteralExpression, _} +import org.apache.flink.table.expressions.{Expression, _} import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction} import org.apache.flink.table.planner.codegen.CodeGenUtils._ import org.apache.flink.table.planner.codegen.OperatorCodeGenerator.STREAM_RECORD import org.apache.flink.table.planner.codegen._ -import org.apache.flink.table.planner.expressions.{ResolvedAggInputReference, ResolvedAggLocalReference, RexNodeConverter} +import org.apache.flink.table.planner.expressions.{DeclarativeExpressionResolver, ResolvedAggInputReference, ResolvedAggLocalReference, RexNodeConverter} import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.{getAccumulatorTypeOfAggregateFunction, getAggUserDefinedInputTypes} import org.apache.flink.table.runtime.context.ExecutionContextImpl @@ -38,12 +37,11 @@ import org.apache.flink.table.runtime.types.LogicalTypeDataTypeConverter.fromDat import org.apache.flink.table.types.DataType import org.apache.flink.table.types.logical.LogicalTypeRoot._ import org.apache.flink.table.types.logical.{LogicalType, RowType} + import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.rex.RexNode import org.apache.calcite.tools.RelBuilder -import scala.collection.JavaConverters._ - /** * Batch aggregate code generate helper. */ @@ -253,60 +251,28 @@ object AggCodeGenHelper { */ private case class ResolveReference( ctx: CodeGeneratorContext, + relBuilder: RelBuilder, isMerge: Boolean, agg: DeclarativeAggregateFunction, aggIndex: Int, argsMapping: Array[Array[(Int, LogicalType)]], - aggBufferTypes: Array[Array[LogicalType]]) extends ExpressionVisitor[Expression] { - - override def visit(call: CallExpression): Expression = ??? + aggBufferTypes: Array[Array[LogicalType]]) + extends DeclarativeExpressionResolver(relBuilder, agg, isMerge) { - override def visit(valueLiteralExpression: ValueLiteralExpression): Expression = { - valueLiteralExpression + override def toMergeInputExpr(name: String, localIndex: Int): ResolvedExpression = { + val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex) + new ResolvedAggInputReference(name, inputIndex, inputType) } - override def visit(input: FieldReferenceExpression): Expression = { - input + override def toAccInputExpr(name: String, localIndex: Int): ResolvedExpression = { + val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex) + new ResolvedAggInputReference(name, inputIndex, inputType) } - override def visit(typeLiteral: TypeLiteralExpression): Expression = { - typeLiteral - } - - private def visitUnresolvedCallExpression( - unresolvedCall: UnresolvedCallExpression): Expression = { - ApiExpressionUtils.unresolvedCall( - unresolvedCall.getFunctionDefinition, - unresolvedCall.getChildren.asScala.map(_.accept(this)): _*) - } - - private def visitUnresolvedFieldReference( - input: UnresolvedReferenceExpression): Expression = { - agg.aggBufferAttributes.indexOf(input) match { - case -1 => - // We always use UnresolvedFieldReference to represent reference of input field. - // In non-merge case, the input is operand of the aggregate function. But in merge - // case, the input is aggregate buffers which sent by local aggregate. - val localIndex = if (isMerge) { - agg.mergeOperands.indexOf(input) - } else { - agg.operands.indexOf(input) - } - val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex) - new ResolvedAggInputReference(input.getName, inputIndex, inputType) - case localIndex => - val variableName = s"agg${aggIndex}_${input.getName}" - newLocalReference( - ctx, variableName, aggBufferTypes(aggIndex)(localIndex)) - } - } - - override def visit(other: Expression): Expression = { - other match { - case u : UnresolvedReferenceExpression => visitUnresolvedFieldReference(u) - case u : UnresolvedCallExpression => visitUnresolvedCallExpression(u) - case _ => other - } + override def toAggBufferExpr(name: String, localIndex: Int): ResolvedExpression = { + val variableName = s"agg${aggIndex}_$name" + newLocalReference( + ctx, variableName, aggBufferTypes(aggIndex)(localIndex)) } } @@ -333,7 +299,7 @@ object AggCodeGenHelper { case (agg: DeclarativeAggregateFunction, aggIndex: Int) => val idx = auxGrouping.length + aggIndex agg.aggBufferAttributes.map(_.accept( - ResolveReference(ctx, isMerge, agg, idx, argsMapping, aggBufferTypes))) + ResolveReference(ctx, builder, isMerge, agg, idx, argsMapping, aggBufferTypes))) case (_: AggregateFunction[_, _], aggIndex: Int) => val idx = auxGrouping.length + aggIndex val variableName = aggBufferNames(idx)(0) @@ -525,7 +491,7 @@ object AggCodeGenHelper { case (agg: DeclarativeAggregateFunction, aggIndex) => val idx = auxGrouping.length + aggIndex agg.getValueExpression.accept(ResolveReference( - ctx, isMerge, agg, idx, argsMapping, aggBufferTypes)) + ctx, builder, isMerge, agg, idx, argsMapping, aggBufferTypes)) case (agg: AggregateFunction[_, _], aggIndex) => val idx = auxGrouping.length + aggIndex (agg, idx) @@ -567,8 +533,8 @@ object AggCodeGenHelper { aggregates.zipWithIndex.flatMap { case (agg: DeclarativeAggregateFunction, aggIndex) => val idx = auxGrouping.length + aggIndex - agg.mergeExpressions.map( - _.accept(ResolveReference(ctx, isMerge = true, agg, idx, argsMapping, aggBufferTypes))) + agg.mergeExpressions.map(_.accept(ResolveReference( + ctx, builder, isMerge = true, agg, idx, argsMapping, aggBufferTypes))) case (agg: AggregateFunction[_, _], aggIndex) => val idx = auxGrouping.length + aggIndex Some(agg, idx) @@ -629,8 +595,8 @@ object AggCodeGenHelper { val aggCall = aggCallToAggFun._1 aggCallToAggFun._2 match { case agg: DeclarativeAggregateFunction => - agg.accumulateExpressions.map(_.accept( - ResolveReference(ctx, isMerge = false, agg, idx, argsMapping, aggBufferTypes))) + agg.accumulateExpressions.map(_.accept(ResolveReference( + ctx, builder, isMerge = false, agg, idx, argsMapping, aggBufferTypes))) .map(e => (e, aggCall)) case agg: AggregateFunction[_, _] => val idx = auxGrouping.length + aggIndex diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala index 2fdf7b7..dec5e56 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/HashAggCodeGenHelper.scala @@ -21,14 +21,13 @@ package org.apache.flink.table.planner.codegen.agg.batch import org.apache.flink.api.java.tuple.{Tuple2 => JTuple2} import org.apache.flink.metrics.Gauge import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, GenericRow, JoinedRow} -import org.apache.flink.table.expressions.utils.ApiExpressionUtils -import org.apache.flink.table.expressions.{Expression, ExpressionVisitor, FieldReferenceExpression, TypeLiteralExpression, UnresolvedCallExpression, UnresolvedReferenceExpression, ValueLiteralExpression, _} +import org.apache.flink.table.expressions.{Expression, _} import org.apache.flink.table.functions.{AggregateFunction, UserDefinedFunction} import org.apache.flink.table.planner.codegen.CodeGenUtils.{binaryRowFieldSetAccess, binaryRowSetNull} import org.apache.flink.table.planner.codegen._ import org.apache.flink.table.planner.codegen.agg.batch.AggCodeGenHelper.buildAggregateArgsMapping import org.apache.flink.table.planner.codegen.sort.SortCodeGenerator -import org.apache.flink.table.planner.expressions.{ResolvedAggInputReference, RexNodeConverter} +import org.apache.flink.table.planner.expressions.{DeclarativeExpressionResolver, ResolvedAggInputReference, RexNodeConverter} import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.plan.utils.SortUtil import org.apache.flink.table.runtime.generated.{NormalizedKeyComputer, RecordComparator} @@ -37,6 +36,7 @@ import org.apache.flink.table.runtime.operators.sort.BufferedKVExternalSorter import org.apache.flink.table.runtime.typeutils.BinaryRowSerializer import org.apache.flink.table.types.DataType import org.apache.flink.table.types.logical.{LogicalType, RowType} + import org.apache.calcite.rel.core.AggregateCall import org.apache.calcite.tools.RelBuilder @@ -292,8 +292,8 @@ object HashAggCodeGenHelper { val getAggValueExprs = aggregates.zipWithIndex.map { case (agg: DeclarativeAggregateFunction, aggIndex) => val idx = auxGrouping.length + aggIndex - agg.getValueExpression.accept( - ResolveReference(ctx, isMerge, bindRefOffset, agg, idx, argsMapping, aggBuffMapping)) + agg.getValueExpression.accept(ResolveReference( + ctx, builder, isMerge, bindRefOffset, agg, idx, argsMapping, aggBuffMapping)) }.map(_.accept(new RexNodeConverter(builder))).map(exprCodegen.generateExpression) val getValueExprs = getAuxGroupingExprs ++ getAggValueExprs @@ -327,58 +327,28 @@ object HashAggCodeGenHelper { */ private case class ResolveReference( ctx: CodeGeneratorContext, + relBuilder: RelBuilder, isMerge: Boolean, offset: Int, agg: DeclarativeAggregateFunction, aggIndex: Int, argsMapping: Array[Array[(Int, LogicalType)]], - aggBuffMapping: Array[Array[(Int, LogicalType)]]) extends ExpressionVisitor[Expression] { - - override def visit(call: CallExpression): Expression = ??? - - override def visit(valueLiteralExpression: ValueLiteralExpression): Expression = { - valueLiteralExpression - } - - override def visit(input: FieldReferenceExpression): Expression = { - input - } - - override def visit(typeLiteral: TypeLiteralExpression): Expression = { - typeLiteral - } + aggBuffMapping: Array[Array[(Int, LogicalType)]]) + extends DeclarativeExpressionResolver(relBuilder, agg, isMerge) { - private def visitUnresolvedCallExpression( - unresolvedCall: UnresolvedCallExpression): Expression = { - ApiExpressionUtils.unresolvedCall( - unresolvedCall.getFunctionDefinition, - unresolvedCall.getChildren.map(_.accept(this)): _*) + override def toMergeInputExpr(name: String, localIndex: Int): ResolvedExpression = { + val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex) + new ResolvedAggInputReference(name, inputIndex, inputType) } - private def visitUnresolvedFieldReference( - input: UnresolvedReferenceExpression): Expression = { - agg.aggBufferAttributes.indexOf(input) match { - case -1 => - // We always use UnresolvedFieldReference to represent reference of input field. - // In non-merge case, the input is operand of the aggregate function. But in merge - // case, the input is aggregate buffers which sent by local aggregate. - val localIndex = - if (isMerge) agg.mergeOperands.indexOf(input) else agg.operands.indexOf(input) - val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex) - new ResolvedAggInputReference(input.getName, inputIndex, inputType) - case localIndex => - val (aggBuffAttrIndex, aggBuffAttrType) = aggBuffMapping(aggIndex)(localIndex) - new ResolvedAggInputReference( - input.getName, offset + aggBuffAttrIndex, aggBuffAttrType) - } + override def toAccInputExpr(name: String, localIndex: Int): ResolvedExpression = { + val (inputIndex, inputType) = argsMapping(aggIndex)(localIndex) + new ResolvedAggInputReference(name, inputIndex, inputType) } - override def visit(other: Expression): Expression = { - other match { - case u : UnresolvedReferenceExpression => visitUnresolvedFieldReference(u) - case u : UnresolvedCallExpression => visitUnresolvedCallExpression(u) - case _ => other - } + override def toAggBufferExpr(name: String, localIndex: Int): ResolvedExpression = { + val (aggBuffAttrIndex, aggBuffAttrType) = aggBuffMapping(aggIndex)(localIndex) + new ResolvedAggInputReference(name, offset + aggBuffAttrIndex, aggBuffAttrType) } } @@ -407,7 +377,7 @@ object HashAggCodeGenHelper { val bindRefOffset = inputType.getFieldCount agg.mergeExpressions.map( _.accept(ResolveReference( - ctx, isMerge = true, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))) + ctx, builder, isMerge = true, bindRefOffset, agg, idx, argsMapping, aggBuffMapping))) }.map(_.accept(new RexNodeConverter(builder))).map(exprCodegen.generateExpression) val aggBufferTypeWithoutAuxGrouping = if (auxGrouping.nonEmpty) { @@ -464,9 +434,8 @@ object HashAggCodeGenHelper { val aggCall = aggCallToAggFun._1 aggCallToAggFun._2 match { case agg: DeclarativeAggregateFunction => - agg.accumulateExpressions.map( - _.accept(ResolveReference( - ctx, isMerge = false, bindRefOffset, agg, idx, argsMapping, aggBuffMapping)) + agg.accumulateExpressions.map(_.accept(ResolveReference( + ctx, builder, isMerge = false, bindRefOffset, agg, idx, argsMapping, aggBuffMapping)) ).map(e => (e, aggCall)) } }.map { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala index 785ac82..8a2d753 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/codegen/agg/batch/WindowCodeGenerator.scala @@ -34,7 +34,7 @@ import org.apache.flink.table.planner.codegen._ import org.apache.flink.table.planner.codegen.agg.batch.AggCodeGenHelper.{buildAggregateArgsMapping, genAggregateByFlatAggregateBuffer, genFlatAggBufferExprs, genInitFlatAggregateBuffer} import org.apache.flink.table.planner.codegen.agg.batch.WindowCodeGenerator.{asLong, isTimeIntervalLiteral} import org.apache.flink.table.planner.expressions.ExpressionBuilder._ -import org.apache.flink.table.planner.expressions.RexNodeConverter +import org.apache.flink.table.planner.expressions.{CallExpressionResolver, RexNodeConverter} import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils.getAccumulatorTypeOfAggregateFunction import org.apache.flink.table.planner.plan.logical.{LogicalWindow, SlidingGroupWindow, TumblingGroupWindow} @@ -695,7 +695,7 @@ abstract class WindowCodeGenerator( plus(remainder, literal(slideSize)), remainder)), literal(index * slideSize)) - exprCodegen.generateExpression(expr.accept( + exprCodegen.generateExpression(new CallExpressionResolver(relBuilder).resolve(expr).accept( new RexNodeConverter(relBuilder.values(inputRowType)))) } diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala index a8b4c11..b6b6905 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/PlannerExpressionConverter.scala @@ -821,6 +821,14 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp other match { // already converted planner expressions will pass this visitor without modification case plannerExpression: PlannerExpression => plannerExpression + case aggInput: ResolvedAggInputReference => PlannerResolvedAggInputReference( + aggInput.getName, aggInput.getIndex, fromDataTypeToTypeInfo(aggInput.getOutputDataType)) + case aggLocal: ResolvedAggLocalReference => PlannerResolvedAggLocalReference( + aggLocal.getFieldTerm, + aggLocal.getNullTerm, + fromDataTypeToTypeInfo(aggLocal.getOutputDataType)) + case key: ResolvedDistinctKeyReference => PlannerResolvedDistinctKeyReference( + key.getName, fromDataTypeToTypeInfo(key.getOutputDataType)) case _ => throw new TableException("Unrecognized expression: " + other) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/fieldExpression.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/fieldExpression.scala index 89e71db..2a3c621 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/fieldExpression.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/expressions/fieldExpression.scala @@ -19,6 +19,7 @@ package org.apache.flink.table.planner.expressions import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.api._ +import org.apache.flink.table.expressions.ResolvedFieldReference import org.apache.flink.table.operations.QueryOperation import org.apache.flink.table.planner.calcite.FlinkRelBuilder.NamedWindowProperty import org.apache.flink.table.planner.calcite.FlinkTypeFactory @@ -228,3 +229,57 @@ case class StreamRecordTimestamp() extends LeafExpression { override private[flink] def resultType = Types.LONG } + +/** + * Normally we should use [[ResolvedFieldReference]] to represent an input field. + * [[ResolvedFieldReference]] uses name to locate the field, in aggregate case, we want to use + * field index. + */ +case class PlannerResolvedAggInputReference( + name: String, + index: Int, + resultType: TypeInformation[_]) extends Attribute { + + override def toString = s"'$name" + + override private[flink] def withName(newName: String): Attribute = { + if (newName == name) this + else PlannerResolvedAggInputReference(newName, index, resultType) + } +} + +/** + * Special reference which represent a local filed, such as aggregate buffers or constants. + * We are stored as class members, so the field can be referenced directly. + * We should use an unique name to locate the field. + */ +case class PlannerResolvedAggLocalReference( + name: String, + nullTerm: String, + resultType: TypeInformation[_]) + extends Attribute { + + override def toString = s"'$name" + + override private[flink] def withName(newName: String): Attribute = { + if (newName == name) this + else PlannerResolvedAggLocalReference(newName, nullTerm, resultType) + } +} + +/** + * Special reference which represent a distinct key input filed, + * [[ResolvedDistinctKeyReference]] uses name to locate the field. + */ +case class PlannerResolvedDistinctKeyReference( + name: String, + resultType: TypeInformation[_]) + extends Attribute { + + override def toString = s"'$name" + + override private[flink] def withName(newName: String): Attribute = { + if (newName == name) this + else PlannerResolvedDistinctKeyReference(newName, resultType) + } +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala index 929c242..ebd2863 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/AggregateUtil.scala @@ -30,9 +30,9 @@ import org.apache.flink.table.planner.calcite.FlinkRelBuilder.PlannerNamedWindow import org.apache.flink.table.planner.calcite.{FlinkTypeFactory, FlinkTypeSystem} import org.apache.flink.table.planner.dataview.DataViewUtils.useNullSerializerForStateViewFieldsFromAccType import org.apache.flink.table.planner.dataview.{DataViewSpec, MapViewSpec} -import org.apache.flink.table.planner.expressions.{PlannerProctimeAttribute, PlannerRowtimeAttribute, PlannerWindowEnd, PlannerWindowStart, RexNodeConverter} +import org.apache.flink.table.planner.expressions.{PlannerProctimeAttribute, PlannerRowtimeAttribute, PlannerWindowEnd, PlannerWindowStart} import org.apache.flink.table.planner.functions.aggfunctions.DeclarativeAggregateFunction -import org.apache.flink.table.planner.functions.sql.{FlinkSqlOperatorTable, SqlListAggFunction, SqlFirstLastValueAggFunction} +import org.apache.flink.table.planner.functions.sql.{FlinkSqlOperatorTable, SqlFirstLastValueAggFunction, SqlListAggFunction} import org.apache.flink.table.planner.functions.utils.AggSqlFunction import org.apache.flink.table.planner.functions.utils.UserDefinedFunctionUtils._ import org.apache.flink.table.planner.plan.`trait`.RelModifiedMonotonicity @@ -49,7 +49,6 @@ import org.apache.flink.table.types.utils.TypeConversions.fromLegacyInfoToDataTy import org.apache.calcite.rel.`type`._ import org.apache.calcite.rel.core.{Aggregate, AggregateCall} -import org.apache.calcite.rex.RexInputRef import org.apache.calcite.sql.fun._ import org.apache.calcite.sql.validate.SqlMonotonicity import org.apache.calcite.sql.{SqlKind, SqlRankFunction} @@ -687,8 +686,7 @@ object AggregateUtil extends Enumeration { */ def timeFieldIndex( inputType: RelDataType, relBuilder: RelBuilder, timeField: FieldReferenceExpression): Int = { - timeField.accept(new RexNodeConverter(relBuilder.values(inputType))) - .asInstanceOf[RexInputRef].getIndex + relBuilder.values(inputType).field(timeField.getName).getIndex } /**