Repository: flink Updated Branches: refs/heads/tableOnCalcite e34e43954 -> d720b002a
http://git-wip-us.apache.org/repos/asf/flink/blob/22621e02/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala index 1b876da..11857df 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala @@ -27,7 +27,8 @@ import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName} import org.apache.calcite.sql.fun._ import org.apache.flink.api.common.functions.{GroupReduceFunction, MapFunction} import org.apache.flink.api.common.typeinfo.TypeInformation -import org.apache.flink.api.table.plan.PlanGenException +import org.apache.flink.api.table.plan.{TypeConverter, PlanGenException} +import org.apache.flink.api.table.plan.TypeConverter._ import org.apache.flink.api.table.typeinfo.RowTypeInfo import org.apache.flink.api.table.{Row, TableConfig} @@ -64,7 +65,8 @@ object AggregateUtil { */ def createOperatorFunctionsForAggregates(namedAggregates: Seq[CalcitePair[AggregateCall, String]], inputType: RelDataType, outputType: RelDataType, - groupings: Array[Int]): AggregateResult = { + groupings: Array[Int], + config: TableConfig): AggregateResult = { val aggregateFunctionsAndFieldIndexes = transformToAggregateFunctions(namedAggregates.map(_.getKey), inputType, groupings.length) @@ -72,20 +74,19 @@ object AggregateUtil { val aggFieldIndexes = aggregateFunctionsAndFieldIndexes._1 val aggregates = aggregateFunctionsAndFieldIndexes._2 - val mapFunction = ( - config: TableConfig, - inputType: TypeInformation[Any], - returnType: TypeInformation[Any]) => { - - val aggregateMapFunction = new AggregateMapFunction[Row, Row]( - aggregates, aggFieldIndexes, groupings, returnType.asInstanceOf[RowTypeInfo]) - - aggregateMapFunction.asInstanceOf[MapFunction[Any, Any]] - } - val bufferDataType: RelRecordType = createAggregateBufferDataType(groupings, aggregates, inputType) + val mapReturnType = determineReturnType( + bufferDataType, + Some(TypeConverter.DEFAULT_ROW_TYPE), + config.getNullCheck, + config.getEfficientTypeUsage) + + val mapFunction = new AggregateMapFunction[Row, Row]( + aggregates, aggFieldIndexes, groupings, + mapReturnType.asInstanceOf[RowTypeInfo]).asInstanceOf[MapFunction[Any, Any]] + // the mapping relation between field index of intermediate aggregate Row and output Row. val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings) @@ -105,16 +106,15 @@ object AggregateUtil { val reduceGroupFunction = if (allPartialAggregate) { - (config: TableConfig, inputType: TypeInformation[Row], returnType: TypeInformation[Row]) => - new AggregateReduceCombineFunction(aggregates, groupingOffsetMapping, - aggOffsetMapping, intermediateRowArity) - } else { - (config: TableConfig, inputType: TypeInformation[Row], returnType: TypeInformation[Row]) => - new AggregateReduceGroupFunction(aggregates, groupingOffsetMapping, - aggOffsetMapping, intermediateRowArity) + new AggregateReduceCombineFunction(aggregates, groupingOffsetMapping, + aggOffsetMapping, intermediateRowArity) + } + else { + new AggregateReduceGroupFunction(aggregates, groupingOffsetMapping, + aggOffsetMapping, intermediateRowArity) } - new AggregateResult(mapFunction, reduceGroupFunction, bufferDataType) + new AggregateResult(mapFunction, reduceGroupFunction) } private def transformToAggregateFunctions( @@ -318,9 +318,6 @@ object AggregateUtil { } case class AggregateResult( - val mapFunc: (TableConfig, TypeInformation[Any], TypeInformation[Any]) => - MapFunction[Any, Any], - val reduceGroupFunc: (TableConfig, TypeInformation[Row], TypeInformation[Row]) => - GroupReduceFunction[Row, Row], - val intermediateDataType: RelDataType) { + val mapFunc: MapFunction[Any, Any], + val reduceGroupFunc: GroupReduceFunction[Row, Row]) { }
