Repository: flink
Updated Branches:
  refs/heads/tableOnCalcite e34e43954 -> d720b002a


http://git-wip-us.apache.org/repos/asf/flink/blob/22621e02/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
index 1b876da..11857df 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/api/table/runtime/aggregate/AggregateUtil.scala
@@ -27,7 +27,8 @@ import org.apache.calcite.sql.`type`.{SqlTypeFactoryImpl, SqlTypeName}
 import org.apache.calcite.sql.fun._
 import org.apache.flink.api.common.functions.{GroupReduceFunction, MapFunction}
 import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.table.plan.PlanGenException
+import org.apache.flink.api.table.plan.{TypeConverter, PlanGenException}
+import org.apache.flink.api.table.plan.TypeConverter._
 import org.apache.flink.api.table.typeinfo.RowTypeInfo
 import org.apache.flink.api.table.{Row, TableConfig}
 
@@ -64,7 +65,8 @@ object AggregateUtil {
    */
   def createOperatorFunctionsForAggregates(namedAggregates: Seq[CalcitePair[AggregateCall, String]],
       inputType: RelDataType, outputType: RelDataType,
-      groupings: Array[Int]): AggregateResult = {
+      groupings: Array[Int],
+      config: TableConfig): AggregateResult = {
 
     val aggregateFunctionsAndFieldIndexes =
       transformToAggregateFunctions(namedAggregates.map(_.getKey), inputType, groupings.length)
@@ -72,20 +74,19 @@ object AggregateUtil {
     val aggFieldIndexes = aggregateFunctionsAndFieldIndexes._1
     val aggregates = aggregateFunctionsAndFieldIndexes._2
 
-    val mapFunction = (
-        config: TableConfig,
-        inputType: TypeInformation[Any],
-        returnType: TypeInformation[Any]) => {
-
-      val aggregateMapFunction = new AggregateMapFunction[Row, Row](
-        aggregates, aggFieldIndexes, groupings, returnType.asInstanceOf[RowTypeInfo])
-
-      aggregateMapFunction.asInstanceOf[MapFunction[Any, Any]]
-    }
-
     val bufferDataType: RelRecordType =
       createAggregateBufferDataType(groupings, aggregates, inputType)
 
+    val mapReturnType = determineReturnType(
+        bufferDataType,
+        Some(TypeConverter.DEFAULT_ROW_TYPE),
+        config.getNullCheck,
+        config.getEfficientTypeUsage)
+
+    val mapFunction = new AggregateMapFunction[Row, Row](
+        aggregates, aggFieldIndexes, groupings,
+        mapReturnType.asInstanceOf[RowTypeInfo]).asInstanceOf[MapFunction[Any, Any]]
+
     // the mapping relation between field index of intermediate aggregate Row and output Row.
     val groupingOffsetMapping = getGroupKeysMapping(inputType, outputType, groupings)
 
@@ -105,16 +106,15 @@ object AggregateUtil {
 
     val reduceGroupFunction =
       if (allPartialAggregate) {
-        (config: TableConfig, inputType: TypeInformation[Row], returnType: TypeInformation[Row]) =>
-          new AggregateReduceCombineFunction(aggregates, groupingOffsetMapping,
-            aggOffsetMapping, intermediateRowArity)
-      } else {
-        (config: TableConfig, inputType: TypeInformation[Row], returnType: TypeInformation[Row]) =>
-          new AggregateReduceGroupFunction(aggregates, groupingOffsetMapping,
-            aggOffsetMapping, intermediateRowArity)
+        new AggregateReduceCombineFunction(aggregates, groupingOffsetMapping,
+          aggOffsetMapping, intermediateRowArity)
+      }
+      else {
+        new AggregateReduceGroupFunction(aggregates, groupingOffsetMapping,
+          aggOffsetMapping, intermediateRowArity)
       }
 
-    new AggregateResult(mapFunction, reduceGroupFunction, bufferDataType)
+    new AggregateResult(mapFunction, reduceGroupFunction)
   }
 
   private def transformToAggregateFunctions(
@@ -318,9 +318,6 @@ object AggregateUtil {
 }
 
 case class AggregateResult(
-    val mapFunc: (TableConfig, TypeInformation[Any], TypeInformation[Any]) =>
-        MapFunction[Any, Any],
-    val reduceGroupFunc: (TableConfig, TypeInformation[Row], TypeInformation[Row]) =>
-        GroupReduceFunction[Row, Row],
-    val intermediateDataType: RelDataType) {
+    val mapFunc: MapFunction[Any, Any],
+    val reduceGroupFunc: GroupReduceFunction[Row, Row]) {
 }

Reply via email to