Repository: flink
Updated Branches:
  refs/heads/master f3cd9c059 -> b31b707cb


[FLINK-8821] [table] Fix non-terminating decimal error

This closes #5608.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b31b707c
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b31b707c
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b31b707c

Branch: refs/heads/master
Commit: b31b707cb20f34633815718ff356e187f3397620
Parents: f3cd9c0
Author: Xpray <leonxp...@gmail.com>
Authored: Fri Mar 2 12:11:45 2018 +0800
Committer: Timo Walther <twal...@apache.org>
Committed: Fri Mar 2 09:18:46 2018 +0100

----------------------------------------------------------------------
 .../apache/flink/table/api/TableConfig.scala    | 21 ++++++++
 .../flink/table/codegen/CodeGenerator.scala     | 18 +++----
 .../table/codegen/calls/ScalarOperators.scala   | 29 ++++++++--
 .../functions/aggfunctions/AvgAggFunction.scala |  7 +--
 .../plan/nodes/dataset/DataSetAggregate.scala   |  3 +-
 .../nodes/dataset/DataSetWindowAggregate.scala  | 55 ++++++++++++-------
 .../datastream/DataStreamGroupAggregate.scala   |  1 +
 .../DataStreamGroupWindowAggregate.scala        |  6 ++-
 .../datastream/DataStreamOverAggregate.scala    | 12 +++--
 .../table/runtime/aggregate/AggregateUtil.scala | 57 ++++++++++++++------
 .../table/expressions/DecimalTypeTest.scala     | 11 +++-
 .../runtime/aggfunctions/AvgFunctionTest.scala  | 12 +++--
 .../table/runtime/batch/table/CalcITCase.scala  |  8 ++-
 13 files changed, 176 insertions(+), 64 deletions(-)
----------------------------------------------------------------------
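
For context on the failure this commit fixes: java.math.BigDecimal#divide
throws an ArithmeticException when the exact quotient has a non-terminating
decimal expansion (e.g. 1/3) and no MathContext is supplied. A minimal
sketch of the problem and the fix, in plain Scala outside of Flink:

    import java.math.{MathContext, BigDecimal => JBigDecimal}

    object NonTerminatingDecimalDemo {
      def main(args: Array[String]): Unit = {
        val one   = new JBigDecimal("1")
        val three = new JBigDecimal("3")

        // Without a context this line throws
        // java.lang.ArithmeticException: Non-terminating decimal expansion
        // one.divide(three)

        // With an explicit context the quotient is rounded and succeeds;
        // DECIMAL128 keeps 34 significant digits with HALF_EVEN rounding:
        println(one.divide(three, MathContext.DECIMAL128))
        // prints 0.3333333333333333333333333333333333
      }
    }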


http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala
index 6448657..c78a022 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala
@@ -18,6 +18,7 @@
 package org.apache.flink.table.api
 
 import _root_.java.util.TimeZone
+import _root_.java.math.MathContext
 
 import org.apache.flink.table.calcite.CalciteConfig
 
@@ -42,6 +43,12 @@ class TableConfig {
   private var calciteConfig = CalciteConfig.DEFAULT
 
   /**
+    * Defines the default context for decimal division calculation.
+    * We use Scala's default MathContext.DECIMAL128.
+    */
+  private var decimalContext = MathContext.DECIMAL128
+
+  /**
    * Sets the timezone for date/time/timestamp conversions.
    */
   def setTimeZone(timeZone: TimeZone): Unit = {
@@ -78,6 +85,20 @@ class TableConfig {
   def setCalciteConfig(calciteConfig: CalciteConfig): Unit = {
     this.calciteConfig = calciteConfig
   }
+
+  /**
+    * Returns the default context for decimal division calculation.
+    * [[_root_.java.math.MathContext#DECIMAL128]] by default.
+    */
+  def getDecimalContext: MathContext = decimalContext
+
+  /**
+    * Sets the default context for decimal division calculation.
+    * [[_root_.java.math.MathContext#DECIMAL128]] by default.
+    */
+  def setDecimalContext(mathContext: MathContext): Unit = {
+    this.decimalContext = mathContext
+  }
 }
 
 object TableConfig {
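
A usage sketch of the new option (the setup around it is illustrative, not
part of this commit):

    import java.math.MathContext
    import org.apache.flink.table.api.TableConfig

    val config = new TableConfig
    // MathContext.DECIMAL128 (34 significant digits) is the default
    println(config.getDecimalContext)
    // e.g. narrow the precision used for decimal division to 10 digits
    config.setDecimalContext(new MathContext(10))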

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
index 756a828..e4064d6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala
@@ -742,56 +742,56 @@ abstract class CodeGenerator(
         val right = operands(1)
         requireNumeric(left)
         requireNumeric(right)
-        generateArithmeticOperator("+", nullCheck, resultType, left, right)
+        generateArithmeticOperator("+", nullCheck, resultType, left, right, 
config)
 
       case PLUS | DATETIME_PLUS if isTemporal(resultType) =>
         val left = operands.head
         val right = operands(1)
         requireTemporal(left)
         requireTemporal(right)
-        generateTemporalPlusMinus(plus = true, nullCheck, left, right)
+        generateTemporalPlusMinus(plus = true, nullCheck, left, right, config)
 
       case MINUS if isNumeric(resultType) =>
         val left = operands.head
         val right = operands(1)
         requireNumeric(left)
         requireNumeric(right)
-        generateArithmeticOperator("-", nullCheck, resultType, left, right)
+        generateArithmeticOperator("-", nullCheck, resultType, left, right, 
config)
 
       case MINUS | MINUS_DATE if isTemporal(resultType) =>
         val left = operands.head
         val right = operands(1)
         requireTemporal(left)
         requireTemporal(right)
-        generateTemporalPlusMinus(plus = false, nullCheck, left, right)
+        generateTemporalPlusMinus(plus = false, nullCheck, left, right, config)
 
       case MULTIPLY if isNumeric(resultType) =>
         val left = operands.head
         val right = operands(1)
         requireNumeric(left)
         requireNumeric(right)
-        generateArithmeticOperator("*", nullCheck, resultType, left, right)
+        generateArithmeticOperator("*", nullCheck, resultType, left, right, 
config)
 
       case MULTIPLY if isTimeInterval(resultType) =>
         val left = operands.head
         val right = operands(1)
         requireTimeInterval(left)
         requireNumeric(right)
-        generateArithmeticOperator("*", nullCheck, resultType, left, right)
+        generateArithmeticOperator("*", nullCheck, resultType, left, right, 
config)
 
       case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) =>
         val left = operands.head
         val right = operands(1)
         requireNumeric(left)
         requireNumeric(right)
-        generateArithmeticOperator("/", nullCheck, resultType, left, right)
+        generateArithmeticOperator("/", nullCheck, resultType, left, right, 
config)
 
       case MOD if isNumeric(resultType) =>
         val left = operands.head
         val right = operands(1)
         requireNumeric(left)
         requireNumeric(right)
-        generateArithmeticOperator("%", nullCheck, resultType, left, right)
+        generateArithmeticOperator("%", nullCheck, resultType, left, right, 
config)
 
       case UNARY_MINUS if isNumeric(resultType) =>
         val operand = operands.head
@@ -922,7 +922,7 @@ abstract class CodeGenerator(
         val left = operands.head
         val right = operands(1)
         requireString(left)
-        generateArithmeticOperator("+", nullCheck, resultType, left, right)
+        generateArithmeticOperator("+", nullCheck, resultType, left, right, 
config)
 
       // rows
       case ROW =>

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala
index a261b3d..57f1618 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala
@@ -17,12 +17,15 @@
  */
 package org.apache.flink.table.codegen.calls
 
+import java.math.MathContext
+
 import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY
 import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange}
 import org.apache.calcite.util.BuiltInMethod
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
 import org.apache.flink.api.common.typeinfo._
 import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo}
+import org.apache.flink.table.api.TableConfig
 import org.apache.flink.table.codegen.CodeGenUtils._
 import org.apache.flink.table.codegen.calls.CallGenerator.generateCallIfArgsNotNull
 import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
@@ -46,7 +49,8 @@ object ScalarOperators {
       nullCheck: Boolean,
       resultType: TypeInformation[_],
       left: GeneratedExpression,
-      right: GeneratedExpression): GeneratedExpression = {
+      right: GeneratedExpression,
+      config: TableConfig): GeneratedExpression = {
 
     val leftCasting = operator match {
       case "%" =>
@@ -68,7 +72,15 @@ object ScalarOperators {
     generateOperatorIfNotNull(nullCheck, resultType, left, right) {
       (leftTerm, rightTerm) =>
         if (isDecimal(resultType)) {
-          s"${leftCasting(leftTerm)}.${arithOpToDecMethod(operator)}(${rightCasting(rightTerm)})"
+          val decMethod = arithOpToDecMethod(operator)
+          operator match {
+            // include math context for decimal division
+            case "/" =>
+              val mathContext = mathContextToString(config.getDecimalContext)
+              s"${leftCasting(leftTerm)}.$decMethod(${rightCasting(rightTerm)}, $mathContext)"
+            case _ =>
+              s"${leftCasting(leftTerm)}.$decMethod(${rightCasting(rightTerm)})"
+          }
         } else {
           s"($resultTypeTerm) (${leftCasting(leftTerm)} $operator 
${rightCasting(rightTerm)})"
         }
@@ -814,14 +826,15 @@ object ScalarOperators {
       plus: Boolean,
       nullCheck: Boolean,
       left: GeneratedExpression,
-      right: GeneratedExpression)
+      right: GeneratedExpression,
+      config: TableConfig)
     : GeneratedExpression = {
 
     val op = if (plus) "+" else "-"
 
     (left.resultType, right.resultType) match {
       case (l: TimeIntervalTypeInfo[_], r: TimeIntervalTypeInfo[_]) if l == r =>
-        generateArithmeticOperator(op, nullCheck, l, left, right)
+        generateArithmeticOperator(op, nullCheck, l, left, right, config)
 
       case (SqlTimeTypeInfo.DATE, TimeIntervalTypeInfo.INTERVAL_MILLIS) =>
         generateOperatorIfNotNull(nullCheck, SqlTimeTypeInfo.DATE, left, right) {
@@ -1290,6 +1303,14 @@ object ScalarOperators {
     case _ => throw new CodeGenException(s"Unsupported decimal arithmetic operator: '$operator'")
   }
 
+  private def mathContextToString(mathContext: MathContext): String = mathContext match {
+    case MathContext.DECIMAL32 => "java.math.MathContext.DECIMAL32"
+    case MathContext.DECIMAL64 => "java.math.MathContext.DECIMAL64"
+    case MathContext.DECIMAL128 => "java.math.MathContext.DECIMAL128"
+    case MathContext.UNLIMITED => "java.math.MathContext.UNLIMITED"
+    case _ => s"""new java.math.MathContext("$mathContext")"""
+  }
+
   private def numericCasting(
       operandType: TypeInformation[_],
       resultType: TypeInformation[_])
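
To make the codegen change concrete: for decimal "/" the generator now emits
a divide call that carries the configured context. A sketch with illustrative
term names (result$1, result$2) standing in for the generator's actual terms:

    val leftTerm = "result$1"
    val rightTerm = "result$2"

    // before: s"$leftTerm.divide($rightTerm)" -- may throw at runtime
    // after, for division only:
    val generated =
      s"$leftTerm.divide($rightTerm, java.math.MathContext.DECIMAL128)"

For non-standard contexts, mathContextToString falls back to
new java.math.MathContext(String), which accepts the format produced by
MathContext#toString.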

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
index b651c42..26621b7 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala
@@ -17,7 +17,7 @@
  */
 package org.apache.flink.table.functions.aggfunctions
 
-import java.math.{BigDecimal, BigInteger}
+import java.math.{BigDecimal, BigInteger, MathContext}
 import java.lang.{Iterable => JIterable}
 
 import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
@@ -295,7 +295,8 @@ class DecimalAvgAccumulator extends JTuple2[BigDecimal, Long] {
 /**
   * Base class for built-in Big Decimal Avg aggregate function
   */
-class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccumulator] {
+class DecimalAvgAggFunction(context: MathContext)
+  extends AggregateFunction[BigDecimal, DecimalAvgAccumulator] {
 
   override def createAccumulator(): DecimalAvgAccumulator = {
     new DecimalAvgAccumulator
@@ -321,7 +322,7 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccu
     if (acc.f1 == 0) {
       null.asInstanceOf[BigDecimal]
     } else {
-      acc.f0.divide(BigDecimal.valueOf(acc.f1))
+      acc.f0.divide(BigDecimal.valueOf(acc.f1), context)
     }
   }
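
The AVG change mirrors the scalar one: the exact average of a decimal group
need not terminate, so getValue now divides under the configured context.
A minimal sketch of the difference in plain Scala:

    import java.math.{BigDecimal, MathContext}

    val sum = BigDecimal.ONE            // acc.f0
    val count = 3L                      // acc.f1

    // old: sum.divide(BigDecimal.valueOf(count)) throws ArithmeticException
    // new: rounded to the configured precision
    val avg = sum.divide(BigDecimal.valueOf(count), MathContext.DECIMAL128)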
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
index 7dd307b..07dcf79 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala
@@ -110,7 +110,8 @@ class DataSetAggregate(
         input.getRowType,
         inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
         rowRelDataType,
-        grouping)
+        grouping,
+        tableEnv.getConfig)
 
     val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil)
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
index 745c4ed..53748f5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala
@@ -25,7 +25,7 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
 import org.apache.flink.api.common.operators.Order
 import org.apache.flink.api.java.DataSet
 import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo}
-import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment}
+import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment, TableConfig}
 import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.AggregationCodeGenerator
@@ -127,11 +127,12 @@ class DataSetWindowAggregate(
           generator,
           inputDS,
           isTimeIntervalLiteral(size),
-          caseSensitive)
+          caseSensitive,
+          tableEnv.getConfig)
 
       case SessionGroupWindow(_, timeField, gap)
           if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
-        createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive)
+        createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive, tableEnv.getConfig)
 
       case SlidingGroupWindow(_, timeField, size, slide)
           if isTimePoint(timeField.resultType) || isLong(timeField.resultType) =>
@@ -141,7 +142,8 @@ class DataSetWindowAggregate(
           isTimeIntervalLiteral(size),
           asLong(size),
           asLong(slide),
-          caseSensitive)
+          caseSensitive,
+          tableEnv.getConfig)
 
       case _ =>
         throw new UnsupportedOperationException(
@@ -153,7 +155,8 @@ class DataSetWindowAggregate(
       generator: AggregationCodeGenerator,
       inputDS: DataSet[Row],
       isTimeWindow: Boolean,
-      isParserCaseSensitive: Boolean): DataSet[Row] = {
+      isParserCaseSensitive: Boolean,
+      tableConfig: TableConfig): DataSet[Row] = {
 
     val input = inputNode.asInstanceOf[DataSetRel]
 
@@ -164,7 +167,8 @@ class DataSetWindowAggregate(
       grouping,
       input.getRowType,
       inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
-      isParserCaseSensitive)
+      isParserCaseSensitive,
+      tableConfig)
     val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
       generator,
       window,
@@ -173,7 +177,8 @@ class DataSetWindowAggregate(
       inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
       getRowType,
       grouping,
-      namedProperties)
+      namedProperties,
+      tableConfig)
 
     val mappedInput = inputDS
       .map(mapFunction)
@@ -215,7 +220,8 @@ class DataSetWindowAggregate(
   private[this] def createEventTimeSessionWindowDataSet(
       generator: AggregationCodeGenerator,
       inputDS: DataSet[Row],
-      isParserCaseSensitive: Boolean): DataSet[Row] = {
+      isParserCaseSensitive: Boolean,
+      tableConfig: TableConfig): DataSet[Row] = {
 
     val input = inputNode.asInstanceOf[DataSetRel]
 
@@ -230,7 +236,8 @@ class DataSetWindowAggregate(
       grouping,
       input.getRowType,
       inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
-      isParserCaseSensitive)
+      isParserCaseSensitive,
+      tableConfig)
 
     val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName)
 
@@ -243,7 +250,8 @@ class DataSetWindowAggregate(
     if (doAllSupportPartialMerge(
       namedAggregates.map(_.getKey),
       inputType,
-      grouping.length)) {
+      grouping.length,
+      tableConfig)) {
 
       // gets the window-start and window-end position in the intermediate result.
       val windowStartPos = rowTimeFieldPos
@@ -257,7 +265,8 @@ class DataSetWindowAggregate(
           namedAggregates,
           input.getRowType,
           inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
-          grouping)
+          grouping,
+          tableConfig)
 
         // create groupReduceFunction for calculating the aggregations
         val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
@@ -269,6 +278,7 @@ class DataSetWindowAggregate(
           rowRelDataType,
           grouping,
           namedProperties,
+          tableConfig,
           isInputCombined = true)
 
         mappedInput
@@ -289,7 +299,8 @@ class DataSetWindowAggregate(
           namedAggregates,
           input.getRowType,
           inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
-          grouping)
+          grouping,
+          tableConfig)
 
         // create groupReduceFunction for calculating the aggregations
         val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction(
@@ -301,6 +312,7 @@ class DataSetWindowAggregate(
           rowRelDataType,
           grouping,
           namedProperties,
+          tableConfig,
           isInputCombined = true)
 
         mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING)
@@ -326,7 +338,8 @@ class DataSetWindowAggregate(
           inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
           rowRelDataType,
           grouping,
-          namedProperties)
+          namedProperties,
+          tableConfig)
 
         mappedInput.groupBy(groupingKeys: _*)
           .sortGroup(rowTimeFieldPos, Order.ASCENDING)
@@ -343,7 +356,8 @@ class DataSetWindowAggregate(
           inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
           rowRelDataType,
           grouping,
-          namedProperties)
+          namedProperties,
+          tableConfig)
 
         mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING).setParallelism(1)
           .reduceGroup(groupReduceFunction)
@@ -360,7 +374,8 @@ class DataSetWindowAggregate(
       isTimeWindow: Boolean,
       size: Long,
       slide: Long,
-      isParserCaseSensitive: Boolean)
+      isParserCaseSensitive: Boolean,
+      tableConfig: TableConfig)
     : DataSet[Row] = {
 
     val input = inputNode.asInstanceOf[DataSetRel]
@@ -374,7 +389,8 @@ class DataSetWindowAggregate(
       grouping,
       input.getRowType,
       inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
-      isParserCaseSensitive)
+      isParserCaseSensitive,
+      tableConfig)
 
     val mappedDataSet = inputDS
       .map(mapFunction)
@@ -389,7 +405,8 @@ class DataSetWindowAggregate(
     val isPartial = doAllSupportPartialMerge(
       namedAggregates.map(_.getKey),
       inputType,
-      grouping.length)
+      grouping.length,
+      tableConfig)
 
     // only pre-tumble if it is worth it
     val isLittleTumblingSize = determineLargestTumblingSize(size, slide) <= 1
@@ -411,7 +428,8 @@ class DataSetWindowAggregate(
           grouping,
           input.getRowType,
           inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes,
-          isParserCaseSensitive)
+          isParserCaseSensitive,
+          tableConfig)
 
         mappedDataSet.asInstanceOf[DataSet[Row]]
           .groupBy(groupingKeysAndAlignedRowtime: _*)
@@ -451,6 +469,7 @@ class DataSetWindowAggregate(
       rowRelDataType,
       grouping,
       namedProperties,
+      tableConfig,
       isInputCombined = false)
 
     // gets the window-start position in the intermediate result.

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
index 71de57c..5f4b186 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala
@@ -138,6 +138,7 @@ class DataStreamGroupAggregate(
       inputSchema.fieldTypeInfos,
       groupings,
       queryConfig,
+      tableEnv.getConfig,
       DataStreamRetractionRules.isAccRetract(this),
       DataStreamRetractionRules.isAccRetract(getInput))
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
index d527dc8..0a014b6 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala
@@ -207,7 +207,8 @@ class DataStreamGroupWindowAggregate(
           inputSchema.fieldTypeInfos,
           schema.relDataType,
           grouping,
-          needMerge)
+          needMerge,
+          tableEnv.getConfig)
 
       windowedStream
         .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, outRowType)
@@ -232,7 +233,8 @@ class DataStreamGroupWindowAggregate(
           inputSchema.fieldTypeInfos,
           schema.relDataType,
           Array[Int](),
-          needMerge)
+          needMerge,
+          tableEnv.getConfig)
 
       windowedStream
         .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, outRowType)

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
index 635c7bc..c1693d9 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala
@@ -28,7 +28,7 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel}
 import org.apache.calcite.rex.RexLiteral
 import org.apache.flink.api.java.functions.NullByteKeySelector
 import org.apache.flink.streaming.api.datastream.DataStream
-import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException}
+import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableConfig, TableException}
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.AggregationCodeGenerator
 import org.apache.flink.table.plan.nodes.OverAggregate
@@ -172,6 +172,7 @@ class DataStreamOverAggregate(
       // unbounded OVER window
       createUnboundedAndCurrentRowOverWindow(
         queryConfig,
+        tableEnv.getConfig,
         generator,
         inputDS,
         rowTimeIdx,
@@ -188,7 +189,8 @@ class DataStreamOverAggregate(
         inputDS,
         rowTimeIdx,
         aggregateInputType,
-        isRowsClause = overWindow.isRows)
+        isRowsClause = overWindow.isRows,
+        tableEnv.getConfig)
     } else {
       throw new TableException("OVER RANGE FOLLOWING windows are not supported 
yet.")
     }
@@ -196,6 +198,7 @@ class DataStreamOverAggregate(
 
   def createUnboundedAndCurrentRowOverWindow(
     queryConfig: StreamQueryConfig,
+    tableConfig: TableConfig,
     generator: AggregationCodeGenerator,
     inputDS: DataStream[CRow],
     rowTimeIdx: Option[Int],
@@ -219,6 +222,7 @@ class DataStreamOverAggregate(
       inputSchema.typeInfo,
       inputSchema.fieldTypeInfos,
       queryConfig,
+      tableConfig,
       rowTimeIdx,
       partitionKeys.nonEmpty,
       isRowsClause)
@@ -249,7 +253,8 @@ class DataStreamOverAggregate(
     inputDS: DataStream[CRow],
     rowTimeIdx: Option[Int],
     aggregateInputType: RelDataType,
-    isRowsClause: Boolean): DataStream[CRow] = {
+    isRowsClause: Boolean,
+    tableConfig: TableConfig): DataStream[CRow] = {
 
     val overWindow: Group = logicWindow.groups.get(0)
 
@@ -272,6 +277,7 @@ class DataStreamOverAggregate(
       inputSchema.fieldTypeInfos,
       precedingOffset,
       queryConfig,
+      tableConfig,
       isRowsClause,
       rowTimeIdx
     )

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
index 361a87e..df9b1c5 100644
--- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
+++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala
@@ -32,7 +32,7 @@ import org.apache.flink.streaming.api.functions.ProcessFunction
 import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction}
 import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow}
 import org.apache.flink.table.api.dataview.DataViewSpec
-import org.apache.flink.table.api.{StreamQueryConfig, TableException}
+import org.apache.flink.table.api.{StreamQueryConfig, TableConfig, TableException}
 import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty
 import org.apache.flink.table.calcite.FlinkTypeFactory
 import org.apache.flink.table.codegen.AggregationCodeGenerator
@@ -78,6 +78,7 @@ object AggregateUtil {
       inputTypeInfo: TypeInformation[Row],
       inputFieldTypeInfo: Seq[TypeInformation[_]],
       queryConfig: StreamQueryConfig,
+      tableConfig: TableConfig,
       rowTimeIdx: Option[Int],
       isPartitioned: Boolean,
       isRowsClause: Boolean)
@@ -88,6 +89,7 @@ object AggregateUtil {
         namedAggregates.map(_.getKey),
         aggregateInputType,
         needRetraction = false,
+        tableConfig,
         isStateBackedDataViews = true)
 
     val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*)
@@ -159,6 +161,7 @@ object AggregateUtil {
       inputFieldTypes: Seq[TypeInformation[_]],
       groupings: Array[Int],
       queryConfig: StreamQueryConfig,
+      tableConfig: TableConfig,
       generateRetraction: Boolean,
       consumeRetraction: Boolean): ProcessFunction[CRow, CRow] = {
 
@@ -167,6 +170,7 @@ object AggregateUtil {
         namedAggregates.map(_.getKey),
         inputRowType,
         consumeRetraction,
+        tableConfig,
         isStateBackedDataViews = true)
 
     val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
@@ -223,6 +227,7 @@ object AggregateUtil {
       inputFieldTypeInfo: Seq[TypeInformation[_]],
       precedingOffset: Long,
       queryConfig: StreamQueryConfig,
+      tableConfig: TableConfig,
       isRowsClause: Boolean,
       rowTimeIdx: Option[Int])
     : ProcessFunction[CRow, CRow] = {
@@ -233,6 +238,7 @@ object AggregateUtil {
         namedAggregates.map(_.getKey),
         aggregateInputType,
         needRetract,
+        tableConfig,
         isStateBackedDataViews = true)
 
     val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*)
@@ -325,14 +331,16 @@ object AggregateUtil {
     groupings: Array[Int],
     inputType: RelDataType,
     inputFieldTypeInfo: Seq[TypeInformation[_]],
-    isParserCaseSensitive: Boolean)
+    isParserCaseSensitive: Boolean,
+    tableConfig: TableConfig)
   : MapFunction[Row, Row] = {
 
     val needRetract = false
    val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       inputType,
-      needRetract)
+      needRetract,
+      tableConfig)
 
     val mapReturnType: RowTypeInfo =
       createRowTypeForKeysAndAggregates(
@@ -430,14 +438,16 @@ object AggregateUtil {
       groupings: Array[Int],
       physicalInputRowType: RelDataType,
       physicalInputTypes: Seq[TypeInformation[_]],
-      isParserCaseSensitive: Boolean)
+      isParserCaseSensitive: Boolean,
+      tableConfig: TableConfig)
     : RichGroupReduceFunction[Row, Row] = {
 
     val needRetract = false
    val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
-      needRetract)
+      needRetract,
+      tableConfig)
 
     val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates(
       groupings,
@@ -543,6 +553,7 @@ object AggregateUtil {
       outputType: RelDataType,
       groupings: Array[Int],
       properties: Seq[NamedWindowProperty],
+      tableConfig: TableConfig,
       isInputCombined: Boolean = false)
     : RichGroupReduceFunction[Row, Row] = {
 
@@ -550,7 +561,8 @@ object AggregateUtil {
     val (aggFieldIndexes, aggregates, _, _) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
-      needRetract)
+      needRetract,
+      tableConfig)
 
     val aggMapping = aggregates.indices.toArray.map(_ + groupings.length)
 
@@ -695,13 +707,15 @@ object AggregateUtil {
     namedAggregates: Seq[CalcitePair[AggregateCall, String]],
     physicalInputRowType: RelDataType,
     physicalInputTypes: Seq[TypeInformation[_]],
-    groupings: Array[Int]): MapPartitionFunction[Row, Row] = {
+    groupings: Array[Int],
+    tableConfig: TableConfig): MapPartitionFunction[Row, Row] = {
 
     val needRetract = false
    val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
-      needRetract)
+      needRetract,
+      tableConfig)
 
     val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
 
@@ -767,14 +781,16 @@ object AggregateUtil {
       namedAggregates: Seq[CalcitePair[AggregateCall, String]],
       physicalInputRowType: RelDataType,
       physicalInputTypes: Seq[TypeInformation[_]],
-      groupings: Array[Int])
+      groupings: Array[Int],
+      tableConfig: TableConfig)
     : GroupCombineFunction[Row, Row] = {
 
     val needRetract = false
    val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       physicalInputRowType,
-      needRetract)
+      needRetract,
+      tableConfig)
 
     val aggMapping = aggregates.indices.map(_ + groupings.length).toArray
 
@@ -831,7 +847,8 @@ object AggregateUtil {
       inputType: RelDataType,
       inputFieldTypeInfo: Seq[TypeInformation[_]],
       outputType: RelDataType,
-      groupings: Array[Int]): (
+      groupings: Array[Int],
+      tableConfig: TableConfig): (
         Option[DataSetPreAggFunction],
         Option[TypeInformation[Row]],
         Either[DataSetAggFunction, DataSetFinalAggFunction]) = {
@@ -840,7 +857,8 @@ object AggregateUtil {
     val (aggInFields, aggregates, accTypes, _) = transformToAggregateFunctions(
       namedAggregates.map(_.getKey),
       inputType,
-      needRetract)
+      needRetract,
+      tableConfig)
 
     val (gkeyOutMapping, aggOutMapping) = getOutputMappings(
       namedAggregates,
@@ -992,7 +1010,8 @@ object AggregateUtil {
       inputFieldTypeInfo: Seq[TypeInformation[_]],
       outputType: RelDataType,
       groupingKeys: Array[Int],
-      needMerge: Boolean)
+      needMerge: Boolean,
+      tableConfig: TableConfig)
     : (DataStreamAggFunction[CRow, Row, Row], RowTypeInfo, RowTypeInfo) = {
 
     val needRetract = false
@@ -1000,7 +1019,8 @@ object AggregateUtil {
       transformToAggregateFunctions(
         namedAggregates.map(_.getKey),
         inputType,
-        needRetract)
+        needRetract,
+        tableConfig)
 
     val aggMapping = aggregates.indices.toArray
     val outputArity = aggregates.length
@@ -1036,12 +1056,14 @@ object AggregateUtil {
   private[flink] def doAllSupportPartialMerge(
     aggregateCalls: Seq[AggregateCall],
     inputType: RelDataType,
-    groupKeysCount: Int): Boolean = {
+    groupKeysCount: Int,
+    tableConfig: TableConfig): Boolean = {
 
     val aggregateList = transformToAggregateFunctions(
       aggregateCalls,
       inputType,
-      needRetraction = false)._2
+      needRetraction = false,
+      tableConfig)._2
 
     doAllSupportPartialMerge(aggregateList)
   }
@@ -1121,6 +1143,7 @@ object AggregateUtil {
       aggregateCalls: Seq[AggregateCall],
       aggregateInputType: RelDataType,
       needRetraction: Boolean,
+      tableConfig: TableConfig,
       isStateBackedDataViews: Boolean = false)
   : (Array[Array[Int]],
     Array[TableAggregateFunction[_, _]],
@@ -1251,7 +1274,7 @@ object AggregateUtil {
               case DOUBLE =>
                 new DoubleAvgAggFunction
               case DECIMAL =>
-                new DecimalAvgAggFunction
+                new DecimalAvgAggFunction(tableConfig.getDecimalContext)
               case sqlType: SqlTypeName =>
                 throw new TableException(s"Avg aggregate does no support type: 
'$sqlType'")
             }
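
In short, AggregateUtil now threads the TableConfig into every
transformToAggregateFunctions call site so that the one built-in aggregate
that divides can be constructed with the configured context, roughly:

    import java.math.MathContext
    import org.apache.flink.table.functions.aggfunctions.DecimalAvgAggFunction

    // simplified; the real factory switches on the aggregate's SQL type
    val decimalContext: MathContext = MathContext.DECIMAL128  // i.e. tableConfig.getDecimalContext
    val avgFunction = new DecimalAvgAggFunction(decimalContext)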

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala
index 42f8008..5de6f2c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala
@@ -233,6 +233,13 @@ class DecimalTypeTest extends ExpressionTestBase {
       "-f0",
       "-f0",
       "-123456789.123456789123456789")
+
+    testAllApis(
+      BigDecimal("1").toExpr / BigDecimal("3"),
+      "1p / 3p",
+      "CAST('1' AS DECIMAL) / CAST('3' AS DECIMAL)",
+      "0.3333333333333333333333333333333333"
+    )
   }
 
   @Test
@@ -287,7 +294,7 @@ class DecimalTypeTest extends ExpressionTestBase {
 
  // ----------------------------------------------------------------------------------------------
 
-  def testData = {
+  def testData: Row = {
     val testData = new Row(6)
     testData.setField(0, BigDecimal("123456789.123456789123456789").bigDecimal)
     testData.setField(1, BigDecimal("123456789123456789123456789").bigDecimal)
@@ -298,7 +305,7 @@ class DecimalTypeTest extends ExpressionTestBase {
     testData
   }
 
-  def typeInfo = {
+  def typeInfo: TypeInformation[Any] = {
     new RowTypeInfo(
       Types.DECIMAL,
       Types.DECIMAL,

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala
index 0671b40..d413c6c 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala
@@ -18,7 +18,7 @@
 
 package org.apache.flink.table.runtime.aggfunctions
 
-import java.math.BigDecimal
+import java.math.{BigDecimal, MathContext}
 
 import org.apache.flink.table.functions.AggregateFunction
 import org.apache.flink.table.functions.aggfunctions._
@@ -178,17 +178,23 @@ class DecimalAvgAggFunctionTest extends AggFunctionTestBase[BigDecimal, DecimalA
       null,
       null,
       null
+    ),
+    Seq(
+      new BigDecimal("0.3"),
+      new BigDecimal("0.3"),
+      new BigDecimal("0.4")
     )
   )
 
   override def expectedResults: Seq[BigDecimal] = Seq(
     BigDecimal.ZERO,
     BigDecimal.ONE,
-    null
+    null,
+    BigDecimal.ONE.divide(new BigDecimal("3"), MathContext.DECIMAL128)
   )
 
   override def aggregator: AggregateFunction[BigDecimal, DecimalAvgAccumulator] =
-    new DecimalAvgAggFunction()
+    new DecimalAvgAggFunction(MathContext.DECIMAL128)
 
   override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any])
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala
index 1b89229..aa37d1b 100644
--- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala
+++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala
@@ -18,6 +18,7 @@
 
 package org.apache.flink.table.runtime.batch.table
 
+import java.math.MathContext
 import java.sql.{Date, Time, Timestamp}
 import java.util
 
@@ -41,6 +42,7 @@ import org.junit.runners.Parameterized
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
+import scala.math.BigDecimal.RoundingMode
 
 @RunWith(classOf[Parameterized])
 class CalcITCase(
@@ -330,6 +332,7 @@ class CalcITCase(
   def testAdvancedDataTypes(): Unit = {
     val env = ExecutionEnvironment.getExecutionEnvironment
     val tEnv = TableEnvironment.getTableEnvironment(env, config)
+    tEnv.getConfig.setDecimalContext(new MathContext(30))
 
     val t = env
       .fromElements((
@@ -341,10 +344,11 @@ class CalcITCase(
       .toTable(tEnv, 'a, 'b, 'c, 'd, 'e)
       .select('a, 'b, 'c, 'd, 'e, BigDecimal("11.2"), BigDecimal("11.2").bigDecimal,
         Date.valueOf("1984-07-12"), Time.valueOf("14:34:24"),
-        Timestamp.valueOf("1984-07-12 14:34:24"))
+        Timestamp.valueOf("1984-07-12 14:34:24"),
+        BigDecimal("1").toExpr / BigDecimal("3"))
 
     val expected = "78.454654654654654,4E+9999,1984-07-12,14:34:24,1984-07-12 
14:34:24.0," +
-      "11.2,11.2,1984-07-12,14:34:24,1984-07-12 14:34:24.0"
+      "11.2,11.2,1984-07-12,14:34:24,1984-07-12 
14:34:24.0,0.333333333333333333333333333333"
     val results = t.toDataSet[Row].collect()
     TestBaseUtils.compareResultAsText(results.asJava, expected)
   }
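
The expected value follows from the context set above: new MathContext(30)
rounds to 30 significant digits (HALF_UP by default), hence exactly thirty
3s in the expected string. A quick check in plain Scala:

    import java.math.{BigDecimal, MathContext}

    val r = new BigDecimal("1").divide(new BigDecimal("3"), new MathContext(30))
    // r.toPlainString == "0.333333333333333333333333333333"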
