Repository: flink Updated Branches: refs/heads/master f3cd9c059 -> b31b707cb
[FLINK-8821] [table] Fix non-terminating decimal error This closes #5608. Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/b31b707c Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/b31b707c Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/b31b707c Branch: refs/heads/master Commit: b31b707cb20f34633815718ff356e187f3397620 Parents: f3cd9c0 Author: Xpray <leonxp...@gmail.com> Authored: Fri Mar 2 12:11:45 2018 +0800 Committer: Timo Walther <twal...@apache.org> Committed: Fri Mar 2 09:18:46 2018 +0100 ---------------------------------------------------------------------- .../apache/flink/table/api/TableConfig.scala | 21 ++++++++ .../flink/table/codegen/CodeGenerator.scala | 18 +++---- .../table/codegen/calls/ScalarOperators.scala | 29 ++++++++-- .../functions/aggfunctions/AvgAggFunction.scala | 7 +-- .../plan/nodes/dataset/DataSetAggregate.scala | 3 +- .../nodes/dataset/DataSetWindowAggregate.scala | 55 ++++++++++++------- .../datastream/DataStreamGroupAggregate.scala | 1 + .../DataStreamGroupWindowAggregate.scala | 6 ++- .../datastream/DataStreamOverAggregate.scala | 12 +++-- .../table/runtime/aggregate/AggregateUtil.scala | 57 ++++++++++++++------ .../table/expressions/DecimalTypeTest.scala | 11 +++- .../runtime/aggfunctions/AvgFunctionTest.scala | 12 +++-- .../table/runtime/batch/table/CalcITCase.scala | 8 ++- 13 files changed, 176 insertions(+), 64 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala index 6448657..c78a022 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/api/TableConfig.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.api import _root_.java.util.TimeZone +import _root_.java.math.MathContext import org.apache.flink.table.calcite.CalciteConfig @@ -42,6 +43,12 @@ class TableConfig { private var calciteConfig = CalciteConfig.DEFAULT /** + * Defines the default context for decimal division calculation. + * We use Scala's default MathContext.DECIMAL128. + */ + private var decimalContext = MathContext.DECIMAL128 + + /** * Sets the timezone for date/time/timestamp conversions. */ def setTimeZone(timeZone: TimeZone): Unit = { @@ -78,6 +85,20 @@ class TableConfig { def setCalciteConfig(calciteConfig: CalciteConfig): Unit = { this.calciteConfig = calciteConfig } + + /** + * Returns the default context for decimal division calculation. + * [[_root_.java.math.MathContext#DECIMAL128]] by default. + */ + def getDecimalContext: MathContext = decimalContext + + /** + * Sets the default context for decimal division calculation. + * [[_root_.java.math.MathContext#DECIMAL128]] by default. + */ + def setDecimalContext(mathContext: MathContext): Unit = { + this.decimalContext = mathContext + } } object TableConfig { http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala index 756a828..e4064d6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/CodeGenerator.scala @@ -742,56 +742,56 @@ abstract class CodeGenerator( val right = operands(1) requireNumeric(left) requireNumeric(right) - generateArithmeticOperator("+", nullCheck, resultType, left, right) + generateArithmeticOperator("+", nullCheck, resultType, left, right, config) case PLUS | DATETIME_PLUS if isTemporal(resultType) => val left = operands.head val right = operands(1) requireTemporal(left) requireTemporal(right) - generateTemporalPlusMinus(plus = true, nullCheck, left, right) + generateTemporalPlusMinus(plus = true, nullCheck, left, right, config) case MINUS if isNumeric(resultType) => val left = operands.head val right = operands(1) requireNumeric(left) requireNumeric(right) - generateArithmeticOperator("-", nullCheck, resultType, left, right) + generateArithmeticOperator("-", nullCheck, resultType, left, right, config) case MINUS | MINUS_DATE if isTemporal(resultType) => val left = operands.head val right = operands(1) requireTemporal(left) requireTemporal(right) - generateTemporalPlusMinus(plus = false, nullCheck, left, right) + generateTemporalPlusMinus(plus = false, nullCheck, left, right, config) case MULTIPLY if isNumeric(resultType) => val left = operands.head val right = operands(1) requireNumeric(left) requireNumeric(right) - generateArithmeticOperator("*", nullCheck, resultType, left, right) + generateArithmeticOperator("*", nullCheck, resultType, left, right, config) case MULTIPLY if isTimeInterval(resultType) => val left = operands.head val right = operands(1) requireTimeInterval(left) requireNumeric(right) - generateArithmeticOperator("*", nullCheck, resultType, left, right) + generateArithmeticOperator("*", nullCheck, resultType, left, right, config) case DIVIDE | DIVIDE_INTEGER if isNumeric(resultType) => val left = operands.head val right = operands(1) requireNumeric(left) requireNumeric(right) - generateArithmeticOperator("/", nullCheck, resultType, left, right) + generateArithmeticOperator("/", nullCheck, resultType, left, right, config) case MOD if isNumeric(resultType) => val left = operands.head val right = operands(1) requireNumeric(left) requireNumeric(right) - generateArithmeticOperator("%", nullCheck, resultType, left, right) + generateArithmeticOperator("%", nullCheck, resultType, left, right, config) case UNARY_MINUS if isNumeric(resultType) => val operand = operands.head @@ -922,7 +922,7 @@ abstract class CodeGenerator( val left = operands.head val right = operands(1) requireString(left) - generateArithmeticOperator("+", nullCheck, resultType, left, right) + generateArithmeticOperator("+", nullCheck, resultType, left, right, config) // rows case ROW => http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala index a261b3d..57f1618 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/codegen/calls/ScalarOperators.scala @@ -17,12 +17,15 @@ */ package org.apache.flink.table.codegen.calls +import java.math.MathContext + import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange} import org.apache.calcite.util.BuiltInMethod import org.apache.flink.api.common.typeinfo.BasicTypeInfo._ import org.apache.flink.api.common.typeinfo._ import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo, RowTypeInfo} +import org.apache.flink.table.api.TableConfig import org.apache.flink.table.codegen.CodeGenUtils._ import org.apache.flink.table.codegen.calls.CallGenerator.generateCallIfArgsNotNull import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression} @@ -46,7 +49,8 @@ object ScalarOperators { nullCheck: Boolean, resultType: TypeInformation[_], left: GeneratedExpression, - right: GeneratedExpression): GeneratedExpression = { + right: GeneratedExpression, + config: TableConfig): GeneratedExpression = { val leftCasting = operator match { case "%" => @@ -68,7 +72,15 @@ object ScalarOperators { generateOperatorIfNotNull(nullCheck, resultType, left, right) { (leftTerm, rightTerm) => if (isDecimal(resultType)) { - s"${leftCasting(leftTerm)}.${arithOpToDecMethod(operator)}(${rightCasting(rightTerm)})" + val decMethod = arithOpToDecMethod(operator) + operator match { + // include math context for decimal division + case "/" => + val mathContext = mathContextToString(config.getDecimalContext) + s"${leftCasting(leftTerm)}.$decMethod(${rightCasting(rightTerm)}, $mathContext)" + case _ => + s"${leftCasting(leftTerm)}.$decMethod(${rightCasting(rightTerm)})" + } } else { s"($resultTypeTerm) (${leftCasting(leftTerm)} $operator ${rightCasting(rightTerm)})" } @@ -814,14 +826,15 @@ object ScalarOperators { plus: Boolean, nullCheck: Boolean, left: GeneratedExpression, - right: GeneratedExpression) + right: GeneratedExpression, + config: TableConfig) : GeneratedExpression = { val op = if (plus) "+" else "-" (left.resultType, right.resultType) match { case (l: TimeIntervalTypeInfo[_], r: TimeIntervalTypeInfo[_]) if l == r => - generateArithmeticOperator(op, nullCheck, l, left, right) + generateArithmeticOperator(op, nullCheck, l, left, right, config) case (SqlTimeTypeInfo.DATE, TimeIntervalTypeInfo.INTERVAL_MILLIS) => generateOperatorIfNotNull(nullCheck, SqlTimeTypeInfo.DATE, left, right) { @@ -1290,6 +1303,14 @@ object ScalarOperators { case _ => throw new CodeGenException(s"Unsupported decimal arithmetic operator: '$operator'") } + private def mathContextToString(mathContext: MathContext): String = mathContext match { + case MathContext.DECIMAL32 => "java.math.MathContext.DECIMAL32" + case MathContext.DECIMAL64 => "java.math.MathContext.DECIMAL64" + case MathContext.DECIMAL128 => "java.math.MathContext.DECIMAL128" + case MathContext.UNLIMITED => "java.math.MathContext.UNLIMITED" + case _ => s"""new java.math.MathContext("$mathContext")""" + } + private def numericCasting( operandType: TypeInformation[_], resultType: TypeInformation[_]) http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala index b651c42..26621b7 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/functions/aggfunctions/AvgAggFunction.scala @@ -17,7 +17,7 @@ */ package org.apache.flink.table.functions.aggfunctions -import java.math.{BigDecimal, BigInteger} +import java.math.{BigDecimal, BigInteger, MathContext} import java.lang.{Iterable => JIterable} import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} @@ -295,7 +295,8 @@ class DecimalAvgAccumulator extends JTuple2[BigDecimal, Long] { /** * Base class for built-in Big Decimal Avg aggregate function */ -class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccumulator] { +class DecimalAvgAggFunction(context: MathContext) + extends AggregateFunction[BigDecimal, DecimalAvgAccumulator] { override def createAccumulator(): DecimalAvgAccumulator = { new DecimalAvgAccumulator @@ -321,7 +322,7 @@ class DecimalAvgAggFunction extends AggregateFunction[BigDecimal, DecimalAvgAccu if (acc.f1 == 0) { null.asInstanceOf[BigDecimal] } else { - acc.f0.divide(BigDecimal.valueOf(acc.f1)) + acc.f0.divide(BigDecimal.valueOf(acc.f1), context) } } http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala index 7dd307b..07dcf79 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetAggregate.scala @@ -110,7 +110,8 @@ class DataSetAggregate( input.getRowType, inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, - grouping) + grouping, + tableEnv.getConfig) val aggString = aggregationToString(inputType, grouping, getRowType, namedAggregates, Nil) http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala index 745c4ed..53748f5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/dataset/DataSetWindowAggregate.scala @@ -25,7 +25,7 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.flink.api.common.operators.Order import org.apache.flink.api.java.DataSet import org.apache.flink.api.java.typeutils.{ResultTypeQueryable, RowTypeInfo} -import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment} +import org.apache.flink.table.api.{BatchQueryConfig, BatchTableEnvironment, TableConfig} import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.AggregationCodeGenerator @@ -127,11 +127,12 @@ class DataSetWindowAggregate( generator, inputDS, isTimeIntervalLiteral(size), - caseSensitive) + caseSensitive, + tableEnv.getConfig) case SessionGroupWindow(_, timeField, gap) if isTimePoint(timeField.resultType) || isLong(timeField.resultType) => - createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive) + createEventTimeSessionWindowDataSet(generator, inputDS, caseSensitive, tableEnv.getConfig) case SlidingGroupWindow(_, timeField, size, slide) if isTimePoint(timeField.resultType) || isLong(timeField.resultType) => @@ -141,7 +142,8 @@ class DataSetWindowAggregate( isTimeIntervalLiteral(size), asLong(size), asLong(slide), - caseSensitive) + caseSensitive, + tableEnv.getConfig) case _ => throw new UnsupportedOperationException( @@ -153,7 +155,8 @@ class DataSetWindowAggregate( generator: AggregationCodeGenerator, inputDS: DataSet[Row], isTimeWindow: Boolean, - isParserCaseSensitive: Boolean): DataSet[Row] = { + isParserCaseSensitive: Boolean, + tableConfig: TableConfig): DataSet[Row] = { val input = inputNode.asInstanceOf[DataSetRel] @@ -164,7 +167,8 @@ class DataSetWindowAggregate( grouping, input.getRowType, inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, - isParserCaseSensitive) + isParserCaseSensitive, + tableConfig) val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( generator, window, @@ -173,7 +177,8 @@ class DataSetWindowAggregate( inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, getRowType, grouping, - namedProperties) + namedProperties, + tableConfig) val mappedInput = inputDS .map(mapFunction) @@ -215,7 +220,8 @@ class DataSetWindowAggregate( private[this] def createEventTimeSessionWindowDataSet( generator: AggregationCodeGenerator, inputDS: DataSet[Row], - isParserCaseSensitive: Boolean): DataSet[Row] = { + isParserCaseSensitive: Boolean, + tableConfig: TableConfig): DataSet[Row] = { val input = inputNode.asInstanceOf[DataSetRel] @@ -230,7 +236,8 @@ class DataSetWindowAggregate( grouping, input.getRowType, inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, - isParserCaseSensitive) + isParserCaseSensitive, + tableConfig) val mappedInput = inputDS.map(mapFunction).name(prepareOperatorName) @@ -243,7 +250,8 @@ class DataSetWindowAggregate( if (doAllSupportPartialMerge( namedAggregates.map(_.getKey), inputType, - grouping.length)) { + grouping.length, + tableConfig)) { // gets the window-start and window-end position in the intermediate result. val windowStartPos = rowTimeFieldPos @@ -257,7 +265,8 @@ class DataSetWindowAggregate( namedAggregates, input.getRowType, inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, - grouping) + grouping, + tableConfig) // create groupReduceFunction for calculating the aggregations val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( @@ -269,6 +278,7 @@ class DataSetWindowAggregate( rowRelDataType, grouping, namedProperties, + tableConfig, isInputCombined = true) mappedInput @@ -289,7 +299,8 @@ class DataSetWindowAggregate( namedAggregates, input.getRowType, inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, - grouping) + grouping, + tableConfig) // create groupReduceFunction for calculating the aggregations val groupReduceFunction = createDataSetWindowAggregationGroupReduceFunction( @@ -301,6 +312,7 @@ class DataSetWindowAggregate( rowRelDataType, grouping, namedProperties, + tableConfig, isInputCombined = true) mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING) @@ -326,7 +338,8 @@ class DataSetWindowAggregate( inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, - namedProperties) + namedProperties, + tableConfig) mappedInput.groupBy(groupingKeys: _*) .sortGroup(rowTimeFieldPos, Order.ASCENDING) @@ -343,7 +356,8 @@ class DataSetWindowAggregate( inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, rowRelDataType, grouping, - namedProperties) + namedProperties, + tableConfig) mappedInput.sortPartition(rowTimeFieldPos, Order.ASCENDING).setParallelism(1) .reduceGroup(groupReduceFunction) @@ -360,7 +374,8 @@ class DataSetWindowAggregate( isTimeWindow: Boolean, size: Long, slide: Long, - isParserCaseSensitive: Boolean) + isParserCaseSensitive: Boolean, + tableConfig: TableConfig) : DataSet[Row] = { val input = inputNode.asInstanceOf[DataSetRel] @@ -374,7 +389,8 @@ class DataSetWindowAggregate( grouping, input.getRowType, inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, - isParserCaseSensitive) + isParserCaseSensitive, + tableConfig) val mappedDataSet = inputDS .map(mapFunction) @@ -389,7 +405,8 @@ class DataSetWindowAggregate( val isPartial = doAllSupportPartialMerge( namedAggregates.map(_.getKey), inputType, - grouping.length) + grouping.length, + tableConfig) // only pre-tumble if it is worth it val isLittleTumblingSize = determineLargestTumblingSize(size, slide) <= 1 @@ -411,7 +428,8 @@ class DataSetWindowAggregate( grouping, input.getRowType, inputDS.getType.asInstanceOf[RowTypeInfo].getFieldTypes, - isParserCaseSensitive) + isParserCaseSensitive, + tableConfig) mappedDataSet.asInstanceOf[DataSet[Row]] .groupBy(groupingKeysAndAlignedRowtime: _*) @@ -451,6 +469,7 @@ class DataSetWindowAggregate( rowRelDataType, grouping, namedProperties, + tableConfig, isInputCombined = false) // gets the window-start position in the intermediate result. http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala index 71de57c..5f4b186 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupAggregate.scala @@ -138,6 +138,7 @@ class DataStreamGroupAggregate( inputSchema.fieldTypeInfos, groupings, queryConfig, + tableEnv.getConfig, DataStreamRetractionRules.isAccRetract(this), DataStreamRetractionRules.isAccRetract(getInput)) http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala index d527dc8..0a014b6 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamGroupWindowAggregate.scala @@ -207,7 +207,8 @@ class DataStreamGroupWindowAggregate( inputSchema.fieldTypeInfos, schema.relDataType, grouping, - needMerge) + needMerge, + tableEnv.getConfig) windowedStream .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, outRowType) @@ -232,7 +233,8 @@ class DataStreamGroupWindowAggregate( inputSchema.fieldTypeInfos, schema.relDataType, Array[Int](), - needMerge) + needMerge, + tableEnv.getConfig) windowedStream .aggregate(aggFunction, windowFunction, accumulatorRowType, aggResultRowType, outRowType) http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala index 635c7bc..c1693d9 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/plan/nodes/datastream/DataStreamOverAggregate.scala @@ -28,7 +28,7 @@ import org.apache.calcite.rel.{RelNode, RelWriter, SingleRel} import org.apache.calcite.rex.RexLiteral import org.apache.flink.api.java.functions.NullByteKeySelector import org.apache.flink.streaming.api.datastream.DataStream -import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableException} +import org.apache.flink.table.api.{StreamQueryConfig, StreamTableEnvironment, TableConfig, TableException} import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.AggregationCodeGenerator import org.apache.flink.table.plan.nodes.OverAggregate @@ -172,6 +172,7 @@ class DataStreamOverAggregate( // unbounded OVER window createUnboundedAndCurrentRowOverWindow( queryConfig, + tableEnv.getConfig, generator, inputDS, rowTimeIdx, @@ -188,7 +189,8 @@ class DataStreamOverAggregate( inputDS, rowTimeIdx, aggregateInputType, - isRowsClause = overWindow.isRows) + isRowsClause = overWindow.isRows, + tableEnv.getConfig) } else { throw new TableException("OVER RANGE FOLLOWING windows are not supported yet.") } @@ -196,6 +198,7 @@ class DataStreamOverAggregate( def createUnboundedAndCurrentRowOverWindow( queryConfig: StreamQueryConfig, + tableConfig: TableConfig, generator: AggregationCodeGenerator, inputDS: DataStream[CRow], rowTimeIdx: Option[Int], @@ -219,6 +222,7 @@ class DataStreamOverAggregate( inputSchema.typeInfo, inputSchema.fieldTypeInfos, queryConfig, + tableConfig, rowTimeIdx, partitionKeys.nonEmpty, isRowsClause) @@ -249,7 +253,8 @@ class DataStreamOverAggregate( inputDS: DataStream[CRow], rowTimeIdx: Option[Int], aggregateInputType: RelDataType, - isRowsClause: Boolean): DataStream[CRow] = { + isRowsClause: Boolean, + tableConfig: TableConfig): DataStream[CRow] = { val overWindow: Group = logicWindow.groups.get(0) @@ -272,6 +277,7 @@ class DataStreamOverAggregate( inputSchema.fieldTypeInfos, precedingOffset, queryConfig, + tableConfig, isRowsClause, rowTimeIdx ) http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala index 361a87e..df9b1c5 100644 --- a/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala +++ b/flink-libraries/flink-table/src/main/scala/org/apache/flink/table/runtime/aggregate/AggregateUtil.scala @@ -32,7 +32,7 @@ import org.apache.flink.streaming.api.functions.ProcessFunction import org.apache.flink.streaming.api.functions.windowing.{AllWindowFunction, WindowFunction} import org.apache.flink.streaming.api.windowing.windows.{Window => DataStreamWindow} import org.apache.flink.table.api.dataview.DataViewSpec -import org.apache.flink.table.api.{StreamQueryConfig, TableException} +import org.apache.flink.table.api.{StreamQueryConfig, TableConfig, TableException} import org.apache.flink.table.calcite.FlinkRelBuilder.NamedWindowProperty import org.apache.flink.table.calcite.FlinkTypeFactory import org.apache.flink.table.codegen.AggregationCodeGenerator @@ -78,6 +78,7 @@ object AggregateUtil { inputTypeInfo: TypeInformation[Row], inputFieldTypeInfo: Seq[TypeInformation[_]], queryConfig: StreamQueryConfig, + tableConfig: TableConfig, rowTimeIdx: Option[Int], isPartitioned: Boolean, isRowsClause: Boolean) @@ -88,6 +89,7 @@ object AggregateUtil { namedAggregates.map(_.getKey), aggregateInputType, needRetraction = false, + tableConfig, isStateBackedDataViews = true) val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*) @@ -159,6 +161,7 @@ object AggregateUtil { inputFieldTypes: Seq[TypeInformation[_]], groupings: Array[Int], queryConfig: StreamQueryConfig, + tableConfig: TableConfig, generateRetraction: Boolean, consumeRetraction: Boolean): ProcessFunction[CRow, CRow] = { @@ -167,6 +170,7 @@ object AggregateUtil { namedAggregates.map(_.getKey), inputRowType, consumeRetraction, + tableConfig, isStateBackedDataViews = true) val aggMapping = aggregates.indices.map(_ + groupings.length).toArray @@ -223,6 +227,7 @@ object AggregateUtil { inputFieldTypeInfo: Seq[TypeInformation[_]], precedingOffset: Long, queryConfig: StreamQueryConfig, + tableConfig: TableConfig, isRowsClause: Boolean, rowTimeIdx: Option[Int]) : ProcessFunction[CRow, CRow] = { @@ -233,6 +238,7 @@ object AggregateUtil { namedAggregates.map(_.getKey), aggregateInputType, needRetract, + tableConfig, isStateBackedDataViews = true) val aggregationStateType: RowTypeInfo = new RowTypeInfo(accTypes: _*) @@ -325,14 +331,16 @@ object AggregateUtil { groupings: Array[Int], inputType: RelDataType, inputFieldTypeInfo: Seq[TypeInformation[_]], - isParserCaseSensitive: Boolean) + isParserCaseSensitive: Boolean, + tableConfig: TableConfig) : MapFunction[Row, Row] = { val needRetract = false val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetract) + needRetract, + tableConfig) val mapReturnType: RowTypeInfo = createRowTypeForKeysAndAggregates( @@ -430,14 +438,16 @@ object AggregateUtil { groupings: Array[Int], physicalInputRowType: RelDataType, physicalInputTypes: Seq[TypeInformation[_]], - isParserCaseSensitive: Boolean) + isParserCaseSensitive: Boolean, + tableConfig: TableConfig) : RichGroupReduceFunction[Row, Row] = { val needRetract = false val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, - needRetract) + needRetract, + tableConfig) val returnType: RowTypeInfo = createRowTypeForKeysAndAggregates( groupings, @@ -543,6 +553,7 @@ object AggregateUtil { outputType: RelDataType, groupings: Array[Int], properties: Seq[NamedWindowProperty], + tableConfig: TableConfig, isInputCombined: Boolean = false) : RichGroupReduceFunction[Row, Row] = { @@ -550,7 +561,8 @@ object AggregateUtil { val (aggFieldIndexes, aggregates, _, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, - needRetract) + needRetract, + tableConfig) val aggMapping = aggregates.indices.toArray.map(_ + groupings.length) @@ -695,13 +707,15 @@ object AggregateUtil { namedAggregates: Seq[CalcitePair[AggregateCall, String]], physicalInputRowType: RelDataType, physicalInputTypes: Seq[TypeInformation[_]], - groupings: Array[Int]): MapPartitionFunction[Row, Row] = { + groupings: Array[Int], + tableConfig: TableConfig): MapPartitionFunction[Row, Row] = { val needRetract = false val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, - needRetract) + needRetract, + tableConfig) val aggMapping = aggregates.indices.map(_ + groupings.length).toArray @@ -767,14 +781,16 @@ object AggregateUtil { namedAggregates: Seq[CalcitePair[AggregateCall, String]], physicalInputRowType: RelDataType, physicalInputTypes: Seq[TypeInformation[_]], - groupings: Array[Int]) + groupings: Array[Int], + tableConfig: TableConfig) : GroupCombineFunction[Row, Row] = { val needRetract = false val (aggFieldIndexes, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), physicalInputRowType, - needRetract) + needRetract, + tableConfig) val aggMapping = aggregates.indices.map(_ + groupings.length).toArray @@ -831,7 +847,8 @@ object AggregateUtil { inputType: RelDataType, inputFieldTypeInfo: Seq[TypeInformation[_]], outputType: RelDataType, - groupings: Array[Int]): ( + groupings: Array[Int], + tableConfig: TableConfig): ( Option[DataSetPreAggFunction], Option[TypeInformation[Row]], Either[DataSetAggFunction, DataSetFinalAggFunction]) = { @@ -840,7 +857,8 @@ object AggregateUtil { val (aggInFields, aggregates, accTypes, _) = transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetract) + needRetract, + tableConfig) val (gkeyOutMapping, aggOutMapping) = getOutputMappings( namedAggregates, @@ -992,7 +1010,8 @@ object AggregateUtil { inputFieldTypeInfo: Seq[TypeInformation[_]], outputType: RelDataType, groupingKeys: Array[Int], - needMerge: Boolean) + needMerge: Boolean, + tableConfig: TableConfig) : (DataStreamAggFunction[CRow, Row, Row], RowTypeInfo, RowTypeInfo) = { val needRetract = false @@ -1000,7 +1019,8 @@ object AggregateUtil { transformToAggregateFunctions( namedAggregates.map(_.getKey), inputType, - needRetract) + needRetract, + tableConfig) val aggMapping = aggregates.indices.toArray val outputArity = aggregates.length @@ -1036,12 +1056,14 @@ object AggregateUtil { private[flink] def doAllSupportPartialMerge( aggregateCalls: Seq[AggregateCall], inputType: RelDataType, - groupKeysCount: Int): Boolean = { + groupKeysCount: Int, + tableConfig: TableConfig): Boolean = { val aggregateList = transformToAggregateFunctions( aggregateCalls, inputType, - needRetraction = false)._2 + needRetraction = false, + tableConfig)._2 doAllSupportPartialMerge(aggregateList) } @@ -1121,6 +1143,7 @@ object AggregateUtil { aggregateCalls: Seq[AggregateCall], aggregateInputType: RelDataType, needRetraction: Boolean, + tableConfig: TableConfig, isStateBackedDataViews: Boolean = false) : (Array[Array[Int]], Array[TableAggregateFunction[_, _]], @@ -1251,7 +1274,7 @@ object AggregateUtil { case DOUBLE => new DoubleAvgAggFunction case DECIMAL => - new DecimalAvgAggFunction + new DecimalAvgAggFunction(tableConfig.getDecimalContext) case sqlType: SqlTypeName => throw new TableException(s"Avg aggregate does no support type: '$sqlType'") } http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala index 42f8008..5de6f2c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/expressions/DecimalTypeTest.scala @@ -233,6 +233,13 @@ class DecimalTypeTest extends ExpressionTestBase { "-f0", "-f0", "-123456789.123456789123456789") + + testAllApis( + BigDecimal("1").toExpr / BigDecimal("3"), + "1p / 3p", + "CAST('1' AS DECIMAL) / CAST('3' AS DECIMAL)", + "0.3333333333333333333333333333333333" + ) } @Test @@ -287,7 +294,7 @@ class DecimalTypeTest extends ExpressionTestBase { // ---------------------------------------------------------------------------------------------- - def testData = { + def testData: Row = { val testData = new Row(6) testData.setField(0, BigDecimal("123456789.123456789123456789").bigDecimal) testData.setField(1, BigDecimal("123456789123456789123456789").bigDecimal) @@ -298,7 +305,7 @@ class DecimalTypeTest extends ExpressionTestBase { testData } - def typeInfo = { + def typeInfo: TypeInformation[Any] = { new RowTypeInfo( Types.DECIMAL, Types.DECIMAL, http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala index 0671b40..d413c6c 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/aggfunctions/AvgFunctionTest.scala @@ -18,7 +18,7 @@ package org.apache.flink.table.runtime.aggfunctions -import java.math.BigDecimal +import java.math.{BigDecimal, MathContext} import org.apache.flink.table.functions.AggregateFunction import org.apache.flink.table.functions.aggfunctions._ @@ -178,17 +178,23 @@ class DecimalAvgAggFunctionTest extends AggFunctionTestBase[BigDecimal, DecimalA null, null, null + ), + Seq( + new BigDecimal("0.3"), + new BigDecimal("0.3"), + new BigDecimal("0.4") ) ) override def expectedResults: Seq[BigDecimal] = Seq( BigDecimal.ZERO, BigDecimal.ONE, - null + null, + BigDecimal.ONE.divide(new BigDecimal("3"), MathContext.DECIMAL128) ) override def aggregator: AggregateFunction[BigDecimal, DecimalAvgAccumulator] = - new DecimalAvgAggFunction() + new DecimalAvgAggFunction(MathContext.DECIMAL128) override def retractFunc = aggregator.getClass.getMethod("retract", accType, classOf[Any]) } http://git-wip-us.apache.org/repos/asf/flink/blob/b31b707c/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala ---------------------------------------------------------------------- diff --git a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala index 1b89229..aa37d1b 100644 --- a/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala +++ b/flink-libraries/flink-table/src/test/scala/org/apache/flink/table/runtime/batch/table/CalcITCase.scala @@ -18,6 +18,7 @@ package org.apache.flink.table.runtime.batch.table +import java.math.MathContext import java.sql.{Date, Time, Timestamp} import java.util @@ -41,6 +42,7 @@ import org.junit.runners.Parameterized import scala.collection.JavaConverters._ import scala.collection.mutable +import scala.math.BigDecimal.RoundingMode @RunWith(classOf[Parameterized]) class CalcITCase( @@ -330,6 +332,7 @@ class CalcITCase( def testAdvancedDataTypes(): Unit = { val env = ExecutionEnvironment.getExecutionEnvironment val tEnv = TableEnvironment.getTableEnvironment(env, config) + tEnv.getConfig.setDecimalContext(new MathContext(30)) val t = env .fromElements(( @@ -341,10 +344,11 @@ class CalcITCase( .toTable(tEnv, 'a, 'b, 'c, 'd, 'e) .select('a, 'b, 'c, 'd, 'e, BigDecimal("11.2"), BigDecimal("11.2").bigDecimal, Date.valueOf("1984-07-12"), Time.valueOf("14:34:24"), - Timestamp.valueOf("1984-07-12 14:34:24")) + Timestamp.valueOf("1984-07-12 14:34:24"), + BigDecimal("1").toExpr / BigDecimal("3")) val expected = "78.454654654654654,4E+9999,1984-07-12,14:34:24,1984-07-12 14:34:24.0," + - "11.2,11.2,1984-07-12,14:34:24,1984-07-12 14:34:24.0" + "11.2,11.2,1984-07-12,14:34:24,1984-07-12 14:34:24.0,0.333333333333333333333333333333" val results = t.toDataSet[Row].collect() TestBaseUtils.compareResultAsText(results.asJava, expected) }