This is an automated email from the ASF dual-hosted git repository. jark pushed a commit to branch release-1.9 in repository https://gitbox.apache.org/repos/asf/flink.git
The following commit(s) were added to refs/heads/release-1.9 by this push: new 5c0521d [FLINK-13314][table-planner-blink] Correct resultType of some PlannerExpression when operands contains DecimalTypeInfo or BigDecimalTypeInfo in Blink planner 5c0521d is described below commit 5c0521d846895598ef4b10a6a66c1b803a3504a6 Author: beyond1920 <beyond1...@126.com> AuthorDate: Wed Jul 17 23:01:12 2019 +0800 [FLINK-13314][table-planner-blink] Correct resultType of some PlannerExpression when operands contains DecimalTypeInfo or BigDecimalTypeInfo in Blink planner This also fix some minor bugs: - Fix minor bug in RexNodeConverter when convert between and not between to RexNode. - Fix minor bug in PlannerExpressionConverter when convert DataType to TypeInformation. This closes #9152 --- .../flink/table/expressions/RexNodeConverter.java | 16 +- .../expressions/PlannerExpressionConverter.scala | 12 +- .../table/expressions/ReturnTypeInference.scala | 217 ++++++++ .../flink/table/expressions/arithmetic.scala | 25 +- .../flink/table/expressions/mathExpressions.scala | 12 +- .../table/runtime/batch/sql/DecimalITCase.scala | 5 +- .../batch/{sql => table}/DecimalITCase.scala | 546 ++++++++------------- 7 files changed, 462 insertions(+), 371 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java index 5528571..5dfeb97 100644 --- a/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java +++ b/flink-table/flink-table-planner-blink/src/main/java/org/apache/flink/table/expressions/RexNodeConverter.java @@ -368,16 +368,24 @@ public class RexNodeConverter implements ExpressionVisitor<RexNode> { private RexNode convertNotBetween(List<Expression> children) { List<RexNode> childrenRexNode = convertCallChildren(children); + Preconditions.checkArgument(childrenRexNode.size() == 3); + RexNode expr = childrenRexNode.get(0); + RexNode lowerBound = childrenRexNode.get(1); + RexNode upperBound = childrenRexNode.get(2); return relBuilder.or( - relBuilder.call(FlinkSqlOperatorTable.LESS_THAN, childrenRexNode), - relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN, childrenRexNode)); + relBuilder.call(FlinkSqlOperatorTable.LESS_THAN, expr, lowerBound), + relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN, expr, upperBound)); } private RexNode convertBetween(List<Expression> children) { List<RexNode> childrenRexNode = convertCallChildren(children); + Preconditions.checkArgument(childrenRexNode.size() == 3); + RexNode expr = childrenRexNode.get(0); + RexNode lowerBound = childrenRexNode.get(1); + RexNode upperBound = childrenRexNode.get(2); return relBuilder.and( - relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, childrenRexNode), - relBuilder.call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, childrenRexNode)); + relBuilder.call(FlinkSqlOperatorTable.GREATER_THAN_OR_EQUAL, expr, lowerBound), + relBuilder.call(FlinkSqlOperatorTable.LESS_THAN_OR_EQUAL, expr, upperBound)); } private RexNode convertCeil(List<Expression> children) { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala index 8b5dada..f53aa1e 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/PlannerExpressionConverter.scala @@ -25,7 +25,7 @@ import org.apache.flink.table.expressions.{E => PlannerE, UUID => PlannerUUID} import org.apache.flink.table.functions._ import org.apache.flink.table.types.logical.LogicalTypeRoot.{CHAR, DECIMAL, SYMBOL, TIMESTAMP_WITHOUT_TIME_ZONE} import org.apache.flink.table.types.logical.utils.LogicalTypeChecks._ -import org.apache.flink.table.types.utils.TypeConversions.fromDataTypeToLegacyInfo +import org.apache.flink.table.types.TypeInfoDataTypeConverter.fromDataTypeToTypeInfo import _root_.scala.collection.JavaConverters._ @@ -53,14 +53,14 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp assert(children.size == 2) return Cast( children.head.accept(this), - fromDataTypeToLegacyInfo( + fromDataTypeToTypeInfo( children(1).asInstanceOf[TypeLiteralExpression].getOutputDataType)) case REINTERPRET_CAST => assert(children.size == 3) Reinterpret( children.head.accept(this), - fromDataTypeToLegacyInfo( + fromDataTypeToTypeInfo( children(1).asInstanceOf[TypeLiteralExpression].getOutputDataType), getValue[Boolean](children(2).accept(this))) @@ -749,7 +749,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp } } - fromDataTypeToLegacyInfo(literal.getOutputDataType) + fromDataTypeToTypeInfo(literal.getOutputDataType) } private def getSymbol(symbol: TableSymbol): PlannerSymbol = symbol match { @@ -786,7 +786,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp override def visit(fieldReference: FieldReferenceExpression): PlannerExpression = { PlannerResolvedFieldReference( fieldReference.getName, - fromDataTypeToLegacyInfo(fieldReference.getOutputDataType)) + fromDataTypeToTypeInfo(fieldReference.getOutputDataType)) } override def visit(fieldReference: UnresolvedReferenceExpression) @@ -834,7 +834,7 @@ class PlannerExpressionConverter private extends ApiExpressionVisitor[PlannerExp private def translateWindowReference(reference: Expression): PlannerExpression = reference match { case expr : LocalReferenceExpression => - WindowReference(expr.getName, Some(fromDataTypeToLegacyInfo(expr.getOutputDataType))) + WindowReference(expr.getName, Some(fromDataTypeToTypeInfo(expr.getOutputDataType))) //just because how the datastream is converted to table case expr: UnresolvedReferenceExpression => UnresolvedFieldReference(expr.getName) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/ReturnTypeInference.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/ReturnTypeInference.scala new file mode 100644 index 0000000..2a333ad --- /dev/null +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/ReturnTypeInference.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.flink.table.expressions + +import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.table.api.TableException +import org.apache.flink.table.calcite.{FlinkTypeFactory, FlinkTypeSystem} +import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.{fromLogicalTypeToTypeInfo, fromTypeInfoToLogicalType} +import org.apache.flink.table.types.logical.{DecimalType, LogicalType} +import org.apache.flink.table.typeutils.{BigDecimalTypeInfo, DecimalTypeInfo, TypeCoercion} + +import org.apache.calcite.rel.`type`.RelDataType +import org.apache.calcite.sql.`type`.SqlTypeUtil + +import scala.collection.JavaConverters._ + +object ReturnTypeInference { + + private lazy val typeSystem = new FlinkTypeSystem + private lazy val typeFactory = new FlinkTypeFactory(typeSystem) + + /** + * Infer resultType of [[Minus]] expression. + * The decimal type inference keeps consistent with Calcite + * [[org.apache.calcite.sql.type.ReturnTypes.NULLABLE_SUM]] which is the return type of + * [[org.apache.calcite.sql.fun.SqlStdOperatorTable.MINUS]]. + * + * @param minus minus Expression + * @return result type + */ + def inferMinus(minus: Minus): TypeInformation[_] = inferPlusOrMinus(minus) + + /** + * Infer resultType of [[Plus]] expression. + * The decimal type inference keeps consistent with Calcite + * [[org.apache.calcite.sql.type.ReturnTypes.NULLABLE_SUM]] which is the return type of + * * [[org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS]]. + * + * @param plus plus Expression + * @return result type + */ + def inferPlus(plus: Plus): TypeInformation[_] = inferPlusOrMinus(plus) + + private def inferPlusOrMinus(op: BinaryArithmetic): TypeInformation[_] = { + val decimalTypeInference = ( + leftType: RelDataType, + rightType: RelDataType, + wideResultType: LogicalType) => { + if (SqlTypeUtil.isExactNumeric(leftType) && + SqlTypeUtil.isExactNumeric(rightType) && + (SqlTypeUtil.isDecimal(leftType) || SqlTypeUtil.isDecimal(rightType))) { + val lp = leftType.getPrecision + val ls = leftType.getScale + val rp = rightType.getPrecision + val rs = rightType.getScale + val scale = Math.max(ls, rs) + assert(scale <= typeSystem.getMaxNumericScale) + var precision = Math.max(lp - ls, rp - rs) + scale + 1 + precision = Math.min(precision, typeSystem.getMaxNumericPrecision) + assert(precision > 0) + fromLogicalTypeToTypeInfo(wideResultType) match { + case _: DecimalTypeInfo => DecimalTypeInfo.of(precision, scale) + case _: BigDecimalTypeInfo => BigDecimalTypeInfo.of(precision, scale) + } + } else { + val resultType = typeFactory.leastRestrictive( + List(leftType, rightType).asJava) + fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(resultType)) + } + } + inferBinaryArithmetic(op, decimalTypeInference, t => fromLogicalTypeToTypeInfo(t)) + } + + /** + * Infer resultType of [[Mul]] expression. + * The decimal type inference keeps consistent with Calcite + * [[org.apache.calcite.sql.type.ReturnTypes.PRODUCT_NULLABLE]] which is the return type of + * * * [[org.apache.calcite.sql.fun.SqlStdOperatorTable.MULTIPLY]]. + * + * @param mul mul Expression + * @return result type + */ + def inferMul(mul: Mul): TypeInformation[_] = { + val decimalTypeInference = ( + leftType: RelDataType, + rightType: RelDataType) => typeFactory.createDecimalProduct(leftType, rightType) + inferDivOrMul(mul, decimalTypeInference) + } + + /** + * Infer resultType of [[Div]] expression. + * The decimal type inference keeps consistent with + * [[org.apache.flink.table.calcite.type.FlinkReturnTypes.FLINK_QUOTIENT_NULLABLE]] which + * is the return type of [[org.apache.flink.table.functions.sql.FlinkSqlOperatorTable.DIVIDE]]. + * + * @param div div Expression + * @return result type + */ + def inferDiv(div: Div): TypeInformation[_] = { + val decimalTypeInference = ( + leftType: RelDataType, + rightType: RelDataType) => typeFactory.createDecimalQuotient(leftType, rightType) + inferDivOrMul(div, decimalTypeInference) + } + + private def inferDivOrMul( + op: BinaryArithmetic, + decimalTypeInfer: (RelDataType, RelDataType) => RelDataType + ): TypeInformation[_] = { + val decimalFunc = ( + leftType: RelDataType, + rightType: RelDataType, + _: LogicalType) => { + val decimalType = decimalTypeInfer(leftType, rightType) + if (decimalType != null) { + fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(decimalType)) + } else { + val resultType = typeFactory.leastRestrictive( + List(leftType, rightType).asJava) + fromLogicalTypeToTypeInfo(FlinkTypeFactory.toLogicalType(resultType)) + } + } + val nonDecimalType = op match { + case _: Div => (_: LogicalType) => BasicTypeInfo.DOUBLE_TYPE_INFO + case _: Mul => (t: LogicalType) => fromLogicalTypeToTypeInfo(t) + } + inferBinaryArithmetic(op, decimalFunc, nonDecimalType) + } + + private def inferBinaryArithmetic( + binaryOp: BinaryArithmetic, + decimalInfer: (RelDataType, RelDataType, LogicalType) => TypeInformation[_], + nonDecimalInfer: LogicalType => TypeInformation[_] + ): TypeInformation[_] = { + val leftType = fromTypeInfoToLogicalType(binaryOp.left.resultType) + val rightType = fromTypeInfoToLogicalType(binaryOp.right.resultType) + TypeCoercion.widerTypeOf(leftType, rightType) match { + case Some(t: DecimalType) => + val leftRelDataType = typeFactory.createFieldTypeFromLogicalType(leftType) + val rightRelDataType = typeFactory.createFieldTypeFromLogicalType(rightType) + decimalInfer(leftRelDataType, rightRelDataType, t) + case Some(t) => nonDecimalInfer(t) + case None => throw new TableException("This will not happen here!") + } + } + + /** + * Infer resultType of [[Round]] expression. + * The decimal type inference keeps consistent with Calcite + * [[org.apache.flink.table.calcite.type.FlinkReturnTypes]].ROUND_FUNCTION_NULLABLE + * + * @param round round Expression + * @return result type + */ + def inferRound(round: Round): TypeInformation[_] = { + val numType = round.left.resultType + numType match { + case _: DecimalTypeInfo | _: BigDecimalTypeInfo => + val lenValue = round.right match { + case Literal(v: Int, BasicTypeInfo.INT_TYPE_INFO) => v + case _ => throw new TableException("This will not happen here!") + } + val numLogicalType = fromTypeInfoToLogicalType(numType) + val numRelDataType = typeFactory.createFieldTypeFromLogicalType(numLogicalType) + val p = numRelDataType.getPrecision + val s = numRelDataType.getScale + val dt = FlinkTypeSystem.inferRoundType(p, s, lenValue) + fromLogicalTypeToTypeInfo(dt) + case t => t + } + } + + /** + * Infer resultType of [[Floor]] expression. + * The decimal type inference keeps consistent with Calcite + * [[org.apache.calcite.sql.type.ReturnTypes]].ARG0_OR_EXACT_NO_SCALE + * + * @param floor floor Expression + * @return result type + */ + def inferFloor(floor: Floor): TypeInformation[_] = getArg0OrExactNoScale(floor) + + /** + * Infer resultType of [[Ceil]] expression. + * The decimal type inference keeps consistent with Calcite + * [[org.apache.calcite.sql.type.ReturnTypes]].ARG0_OR_EXACT_NO_SCALE + * + * @param ceil ceil Expression + * @return result type + */ + def inferCeil(ceil: Ceil): TypeInformation[_] = getArg0OrExactNoScale(ceil) + + private def getArg0OrExactNoScale(op: UnaryExpression) = { + val childType = op.child.resultType + childType match { + case t: DecimalTypeInfo => DecimalTypeInfo.of(t.precision(), 0) + case t: BigDecimalTypeInfo => BigDecimalTypeInfo.of(t.precision(), 0) + case _ => childType + } + } + +} diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala index 726d9ff..20a4ba2 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/arithmetic.scala @@ -17,10 +17,10 @@ */ package org.apache.flink.table.expressions -import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation} +import org.apache.flink.api.common.typeinfo.TypeInformation import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.{fromLogicalTypeToTypeInfo, fromTypeInfoToLogicalType} -import org.apache.flink.table.typeutils.{DecimalTypeInfo, TypeCoercion} +import org.apache.flink.table.typeutils.TypeCoercion import org.apache.flink.table.typeutils.TypeInfoCheckUtils._ import org.apache.flink.table.validate._ @@ -71,6 +71,10 @@ case class Plus(left: PlannerExpression, right: PlannerExpression) extends Binar s"but was '$left' : '${left.resultType}' and '$right' : '${right.resultType}'.") } } + + override private[flink] def resultType: TypeInformation[_] = { + ReturnTypeInference.inferPlus(this) + } } case class UnaryMinus(child: PlannerExpression) extends UnaryExpression { @@ -111,6 +115,10 @@ case class Minus(left: PlannerExpression, right: PlannerExpression) extends Bina s"but was '$left' : '${left.resultType}' and '$right' : '${right.resultType}'.") } } + + override private[flink] def resultType: TypeInformation[_] = { + ReturnTypeInference.inferMinus(this) + } } case class Div(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic { @@ -118,17 +126,20 @@ case class Div(left: PlannerExpression, right: PlannerExpression) extends Binary private[flink] val sqlOperator = FlinkSqlOperatorTable.DIVIDE - override private[flink] def resultType: TypeInformation[_] = - super.resultType match { - case dt: DecimalTypeInfo => dt - case _ => BasicTypeInfo.DOUBLE_TYPE_INFO - } + override private[flink] def resultType: TypeInformation[_] = { + ReturnTypeInference.inferDiv(this) + } + } case class Mul(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic { override def toString = s"($left * $right)" private[flink] val sqlOperator = FlinkSqlOperatorTable.MULTIPLY + + override private[flink] def resultType: TypeInformation[_] = { + ReturnTypeInference.inferMul(this) + } } case class Mod(left: PlannerExpression, right: PlannerExpression) extends BinaryArithmetic { diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala index 7c9d3fd..c28d2b8 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/expressions/mathExpressions.scala @@ -32,7 +32,9 @@ case class Abs(child: PlannerExpression) extends UnaryExpression { } case class Ceil(child: PlannerExpression) extends UnaryExpression { - override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO + override private[flink] def resultType: TypeInformation[_] = { + ReturnTypeInference.inferCeil(this) + } override private[flink] def validateInput(): ValidationResult = TypeInfoCheckUtils.assertNumericExpr(child.resultType, "Ceil") @@ -50,7 +52,9 @@ case class Exp(child: PlannerExpression) extends UnaryExpression with InputTypeS case class Floor(child: PlannerExpression) extends UnaryExpression { - override private[flink] def resultType: TypeInformation[_] = LONG_TYPE_INFO + override private[flink] def resultType: TypeInformation[_] = { + ReturnTypeInference.inferFloor(this) + } override private[flink] def validateInput(): ValidationResult = TypeInfoCheckUtils.assertNumericExpr(child.resultType, "Floor") @@ -258,7 +262,9 @@ case class Sign(child: PlannerExpression) extends UnaryExpression { case class Round(left: PlannerExpression, right: PlannerExpression) extends BinaryExpression { - override private[flink] def resultType: TypeInformation[_] = left.resultType + override private[flink] def resultType: TypeInformation[_] = { + ReturnTypeInference.inferRound(this) + } override private[flink] def validateInput(): ValidationResult = { if (!TypeInfoCheckUtils.isInteger(right.resultType)) { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala index 0fc7e2e..137d277 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala @@ -28,7 +28,7 @@ import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.fromLogicalType import org.apache.flink.table.types.logical.{DecimalType, LogicalType} import org.apache.flink.types.Row -import org.junit.{Assert, Ignore, Test} +import org.junit.{Assert, Test} import java.math.{BigDecimal => JBigDecimal} @@ -591,7 +591,6 @@ class DecimalITCase extends BatchTestBase { s1r(null)) } - @Ignore @Test def testAggMinMaxCount(): Unit = { @@ -862,7 +861,6 @@ class DecimalITCase extends BatchTestBase { s1r(1L)) } - @Ignore @Test def testGroupBy(): Unit = { checkQuery1( @@ -896,7 +894,6 @@ class DecimalITCase extends BatchTestBase { s1r(d"100.000", null, null)) } - @Ignore @Test def testAggAvgGroupBy(): Unit = { diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/table/DecimalITCase.scala similarity index 59% copy from flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala copy to flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/table/DecimalITCase.scala index 0fc7e2e..8bb6054 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/sql/DecimalITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/runtime/batch/table/DecimalITCase.scala @@ -16,87 +16,61 @@ * limitations under the License. */ -package org.apache.flink.table.runtime.batch.sql +package org.apache.flink.table.runtime.batch.table import org.apache.flink.api.java.typeutils.RowTypeInfo -import org.apache.flink.table.api.{DataTypes, ExecutionConfigOptions} -import org.apache.flink.table.runtime.utils.BatchTestBase +import org.apache.flink.table.api.{DataTypes, ExecutionConfigOptions, Table} +import org.apache.flink.table.api.scala._ import org.apache.flink.table.runtime.utils.BatchTestBase.row +import org.apache.flink.table.runtime.utils.{BatchTableEnvUtil, BatchTestBase} import org.apache.flink.table.types.LogicalTypeDataTypeConverter.fromDataTypeToLogicalType import org.apache.flink.table.types.PlannerTypeUtils.isInteroperable import org.apache.flink.table.types.TypeInfoLogicalTypeConverter.fromLogicalTypeToTypeInfo import org.apache.flink.table.types.logical.{DecimalType, LogicalType} import org.apache.flink.types.Row -import org.junit.{Assert, Ignore, Test} +import org.junit.{Assert, Test} import java.math.{BigDecimal => JBigDecimal} import scala.collection.Seq /** - * Conformance test of SQL type Decimal(p,s). + * Conformance test of TableApi type Decimal(p,s). * Served also as documentation of our Decimal behaviors. */ class DecimalITCase extends BatchTestBase { - private case class Coll(colTypes: Seq[LogicalType], rows: Seq[Row]) - - private var globalTableId = 0 - private def checkQueryX( - tables: Seq[Coll], - query: String, - expected: Coll, - isSorted: Boolean = false) - : Unit = { - - var tableId = 0 - var queryX = query - tables.foreach{ table => - tableId += 1 - globalTableId += 1 - val tableName = "Table" + tableId - val newTableName = tableName + "_" + globalTableId - val rowTypeInfo = new RowTypeInfo(table.colTypes.toArray.map(fromLogicalTypeToTypeInfo): _*) - val fieldNames = rowTypeInfo.getFieldNames.mkString(",") - registerCollection(newTableName, table.rows, rowTypeInfo, fieldNames) - queryX = queryX.replace(tableName, newTableName) - } + private def checkQuery( + sourceColTypes: Seq[LogicalType], + sourceRows: Seq[Row], + tableTransfer: Table => Table, + expectedColTypes: Seq[LogicalType], + expectedRows: Seq[Row], + isSorted: Boolean = false): Unit = { + val rowTypeInfo = new RowTypeInfo(sourceColTypes.toArray.map(fromLogicalTypeToTypeInfo): _*) + val fieldNames = rowTypeInfo.getFieldNames.mkString(",") + val t = BatchTableEnvUtil.fromCollection(tEnv, sourceRows, rowTypeInfo, fieldNames) // check result schema - val resultTable = parseQuery(queryX) - val ts1 = expected.colTypes + val resultTable = tableTransfer(t) val ts2 = resultTable.getSchema.getFieldDataTypes.map(fromDataTypeToLogicalType) - Assert.assertEquals(ts1.length, ts2.length) + Assert.assertEquals(expectedColTypes.length, ts2.length) - Assert.assertTrue(ts1.zip(ts2).forall { + Assert.assertTrue(expectedColTypes.zip(ts2).forall { case (t1, t2) => isInteroperable(t1, t2) }) def prepareResult(isSorted: Boolean, seq: Seq[Row]) = { if (!isSorted) seq.map(_.toString).sortBy(s => s) else seq.map(_.toString) } + val resultRows = executeQuery(resultTable) Assert.assertEquals( - prepareResult(isSorted, expected.rows), + prepareResult(isSorted, expectedRows), prepareResult(isSorted, resultRows)) } - private def checkQuery1( - sourceColTypes: Seq[LogicalType], - sourceRows: Seq[Row], - query: String, - expectedColTypes: Seq[LogicalType], - expectedRows: Seq[Row], - isSorted: Boolean = false) - : Unit = { - checkQueryX( - Seq(Coll(sourceColTypes, sourceRows)), - query, - Coll(expectedColTypes, expectedRows), - isSorted) - } - // a Seq of one Row private def s1r(args: Any*): Seq[Row] = Seq(row(args: _*)) @@ -122,9 +96,13 @@ class DecimalITCase extends BatchTestBase { private def DECIMAL = (p: Int, s: Int) => new DecimalType(p, s) private def BOOL = DataTypes.BOOLEAN.getLogicalType + private def INT = DataTypes.INT.getLogicalType + private def LONG = DataTypes.BIGINT.getLogicalType + private def DOUBLE = DataTypes.DOUBLE.getLogicalType + private def STRING = DataTypes.STRING.getLogicalType // d"xxx" => new BigDecimal("xxx") @@ -145,87 +123,62 @@ class DecimalITCase extends BatchTestBase { def testDataSource(): Unit = { // the most basic case - checkQuery1( + + checkQuery( Seq(DECIMAL(10, 0), DECIMAL(7, 2)), s1r(d"123", d"123.45"), - "select * from Table1", + table => table.select('*), Seq(DECIMAL(10, 0), DECIMAL(7, 2)), s1r(d"123", d"123.45")) // data from source are rounded to their declared scale before entering next step - checkQuery1( + checkQuery( Seq(DECIMAL(7, 2)), s1r(d"100.004"), - "select f0, f0+f0 from Table1", // 100.00+100.00 + table => table.select('f0, 'f0 + 'f0), // 100.00+100.00 Seq(DECIMAL(7, 2), DECIMAL(8, 2)), - s1r(d"100.00", d"200.00")) // not 200.008=>200.01 + s1r(d"100.00", d"200.00")) // not 200.008=>200.01 // trailing zeros are padded to the scale - checkQuery1( + checkQuery( Seq(DECIMAL(7, 2)), s1r(d"100.1"), - "select f0, f0+f0 from Table1", // 100.00+100.00 + table => table.select('f0, 'f0 + 'f0), // 100.00+100.00 Seq(DECIMAL(7, 2), DECIMAL(8, 2)), s1r(d"100.10", d"200.20")) // source data is within precision after rounding - checkQuery1( + checkQuery( Seq(DECIMAL(5, 2)), s1r(d"100.0040"), // p=7 => rounding => p=5 - "select f0, f0+f0 from Table1", + table => table.select('f0, 'f0 + 'f0), // 100.00+100.00 Seq(DECIMAL(5, 2), DECIMAL(6, 2)), s1r(d"100.00", d"200.00")) // source data overflows over precision (after rounding) - checkQuery1( + checkQuery( Seq(DECIMAL(2, 0)), s1r(d"123"), - "select * from Table1", + table => table.select('*), Seq(DECIMAL(2, 0)), s1r(null)) - checkQuery1( + checkQuery( Seq(DECIMAL(4, 2)), s1r(d"123.0000"), - "select * from Table1", + table => table.select('*), Seq(DECIMAL(4, 2)), s1r(null)) } @Test - def testLiterals(): Unit = { - - checkQuery1( - Seq(DECIMAL(1,0)), - s1r(d"0"), - "select 12, 12.3, 12.34 from Table1", - Seq(INT, DECIMAL(3, 1), DECIMAL(4, 2)), - s1r(12, d"12.3", d"12.34")) - - checkQuery1( - Seq(DECIMAL(1,0)), - s1r(d"0"), - "select 123456789012345678901234567890.12345678 from Table1", - Seq(DECIMAL(38, 8)), - s1r(d"123456789012345678901234567890.12345678")) - - expectOverflow(()=> - checkQuery1( - Seq(DECIMAL(1,0)), - s1r(d"0"), - "select 123456789012345678901234567890.123456789 from Table1", - Seq(DECIMAL(38, 9)), - s1r(d"123456789012345678901234567890.123456789"))) - } - - @Test def testUnaryPlusMinus(): Unit = { - checkQuery1( + checkQuery( Seq(DECIMAL(10, 0), DECIMAL(7, 2)), s1r(d"123", d"123.45"), - "select +f0, -f1, -((+f0)-(-f1)) from Table1", - Seq(DECIMAL(10, 0), DECIMAL(7, 2), DECIMAL(13,2)), + table => table.select( + 'f0, - 'f1, - (( + 'f0) - ( - 'f1))), + Seq(DECIMAL(10, 0), DECIMAL(7, 2), DECIMAL(13, 2)), s1r(d"123", d"-123.45", d"-246.45")) } @@ -235,19 +188,19 @@ class DecimalITCase extends BatchTestBase { // see calcite ReturnTypes.DECIMAL_SUM // s = max(s1,s2), p-s = max(p1-s1, p2-s2) + 1 // p then is capped at 38 - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2), DECIMAL(10, 4)), s1r(d"100.12", d"200.1234"), - "select f0+f1, f0-f1 from Table1", + table => table.select('f0 + 'f1, 'f0 - 'f1), Seq(DECIMAL(13, 4), DECIMAL(13, 4)), s1r(d"300.2434", d"-100.0034")) // INT => DECIMAL(10,0) // approximate + exact => approximate - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2), INT, DOUBLE), s1r(d"100.00", 200, 3.14), - "select f0+f1, f1+f0, f0+f2, f2+f0 from Table1", + table => table.select('f0 + 'f1, 'f1 + 'f0, 'f0 + 'f2, 'f2 + 'f0), Seq(DECIMAL(13, 2), DECIMAL(13, 2), DOUBLE, DOUBLE), s1r(d"300.00", d"300.00", d"103.14", d"103.14")) @@ -257,31 +210,31 @@ class DecimalITCase extends BatchTestBase { // (38,10)+(38,28)=>(57,28)=>(38,28) // T-SQL -- scale may be reduced to keep the integral part. approximation may occur // (38,10)+(38,28)=>(57,28)=>(38,9) - checkQuery1( + checkQuery( Seq(DECIMAL(38, 10), DECIMAL(38, 28)), s1r(d"100.0123456789", d"200.0123456789012345678901234567"), - "select f0+f1, f0-f1 from Table1", + table => table.select('f0 + 'f1, 'f0 - 'f1), Seq(DECIMAL(38, 28), DECIMAL(38, 28)), s1r(d"300.0246913578012345678901234567", d"-100.0000000000012345678901234567")) - checkQuery1( + checkQuery( Seq(DECIMAL(38, 10), DECIMAL(38, 28)), s1r(d"1e10", d"0"), - "select f1+f0, f1-f0 from Table1", + table => table.select('f1 + 'f0, 'f1 - 'f0 ), Seq(DECIMAL(38, 28), DECIMAL(38, 28)), // 10 digits integral part s1r(null, null)) - checkQuery1( + checkQuery( Seq(DECIMAL(38, 0)), s1r(d"5e37"), - "select f0+f0 from Table1", + table => table.select('f0 + 'f0 ), Seq(DECIMAL(38, 0)), s1r(null)) // requires 39 digits - checkQuery1( + checkQuery( Seq(DECIMAL(38, 0), DECIMAL(38, 0)), s1r(d"5e37", d"5e37"), - "select f0+f0-f1 from Table1", // overflows in subexpression + table => table.select('f0 + 'f0 -'f1 ), // overflows in subexpression Seq(DECIMAL(38, 0)), s1r(null)) } @@ -293,207 +246,166 @@ class DecimalITCase extends BatchTestBase { // s = s1+s2, p = p1+p2 // both p&s are capped at 38 // if s>38, result is rounded to s=38, and the integral part can only be zero - checkQuery1( + checkQuery( Seq(DECIMAL(5, 2), DECIMAL(10, 4)), s1r(d"1.00", d"2.0000"), - "select f0*f0, f0*f1 from Table1", + table => table.select('f0*'f0, 'f0*'f1 ), Seq(DECIMAL(10, 4), DECIMAL(15, 6)), s1r(d"1.0000", d"2.000000")) // INT => DECIMAL(10,0) // approximate * exact => approximate - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2), INT, DOUBLE), s1r(d"1.00", 200, 3.14), - "select f0*f1, f1*f0, f0*f2, f2*f0 from Table1", + table => table.select('f0*'f1, 'f1*'f0, 'f0*'f2, 'f2*'f0 ), Seq(DECIMAL(20, 2), DECIMAL(20, 2), DOUBLE, DOUBLE), s1r(d"200.00", d"200.00", 3.14, 3.14)) // precision is capped at 38; scale will not be reduced (unless over 38) // similar to plus&minus, and calcite behavior is different from T-SQL. - checkQuery1( + checkQuery( Seq(DECIMAL(30, 6), DECIMAL(30, 10)), s1r(d"1", d"2"), - "select f0*f0, f0*f1 from Table1", + table => table.select('f0*'f0, 'f0*'f1 ), Seq(DECIMAL(38, 12), DECIMAL(38, 16)), s1r(d"1${12}", d"2${16}")) - checkQuery1( + checkQuery( Seq(DECIMAL(30, 20)), s1r(d"0.1"), - "select f0*f0 from Table1", - Seq(DECIMAL(38, 38)), // (60,40)=>(38,38) + table => table.select('f0*'f0 ), + Seq(DECIMAL(38, 38)), // (60,40)=>(38,38) s1r(d"0.01${38}")) // scalastyle:off // we don't have this ridiculous behavior: - // https://blogs.msdn.microsoft.com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/ + // https://blogs.msdn.microsoft + // .com/sqlprogrammability/2006/03/29/multiplication-and-division-with-numerics/ // scalastyle:on - checkQuery1( + checkQuery( Seq(DECIMAL(38, 10), DECIMAL(38, 10)), s1r(d"0.0000006", d"1.0"), - "select f0*f1 from Table1", + table => table.select('f0*'f1 ), Seq(DECIMAL(38, 20)), s1r(d"0.0000006${20}")) // result overflow - checkQuery1( + checkQuery( Seq(DECIMAL(38, 0)), s1r(d"1e19"), - "select f0*f0 from Table1", + table => table.select('f0*'f0 ), Seq(DECIMAL(38, 0)), s1r(null)) - checkQuery1( + checkQuery( Seq(DECIMAL(30, 20)), s1r(d"1.0"), - "select f0*f0 from Table1", - Seq(DECIMAL(38, 38)), // (60,40)=>(38,38), no space for integral part + table => table.select('f0*'f0 ), + Seq(DECIMAL(38, 38)), // (60,40)=>(38,38), no space for integral part s1r(null)) } @Test def testDivide(): Unit = { - // the default impl of Calcite apparently borrows from T-SQL, but differs in details. - // Flink overrides it to follow T-SQL exactly. See FlinkTypeFactory.createDecimalQuotient() - checkQuery1( // test 1/3 in different scales +// // the default impl of Calcite apparently borrows from T-SQL, but differs in details. +// // Flink overrides it to follow T-SQL exactly. See FlinkTypeFactory.createDecimalQuotient() + checkQuery( // test 1/3 in different scales Seq(DECIMAL(20, 2), DECIMAL(2, 1), DECIMAL(4, 3), DECIMAL(20, 10), DECIMAL(20, 16)), s1r(d"1.00", d"3", d"3", d"3", d"3"), - "select f0/f1, f0/f2, f0/f3, f0/f4 from Table1", + table => table.select('f0/'f1, 'f0/'f2, 'f0/'f3, 'f0/'f4 ), Seq(DECIMAL(25, 6), DECIMAL(28, 7), DECIMAL(38, 10), DECIMAL(38, 6)), s1r(d"0.333333", d"0.3333333", d"0.3333333333", d"0.333333")) // INT => DECIMAL(10,0) // approximate / exact => approximate - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2), INT, DOUBLE), s1r(d"1.00", 2, 3.0), - "select f0/f1, f1/f0, f0/f2, f2/f0 from Table1", + table => table.select('f0/'f1, 'f1/'f0, 'f0/'f2, 'f2/'f0 ), Seq(DECIMAL(21, 13), DECIMAL(23, 11), DOUBLE, DOUBLE), - s1r(d"0.5${13}", d"2${11}", 1.0/3.0, 3.0/1.0)) + s1r(d"0.5${13}", d"2${11}", 1.0 / 3.0, 3.0 / 1.0)) // result overflow, because result type integral part is reduced - checkQuery1( + checkQuery( Seq(DECIMAL(30, 0), DECIMAL(30, 20)), s1r(d"1e20", d"1e-15"), - "select f0/f1 from Table1", + table => table.select('f0/'f1 ), Seq(DECIMAL(38, 6)), s1r(null)) } @Test def testMod(): Unit = { - - // MOD(Exact1, Exact2) => Exact2 - checkQuery1( - Seq(DECIMAL(10, 2), DECIMAL(10, 4), INT), - s1r(d"3.00", d"5.0000", 7), - "select mod(f0,f1), mod(f1,f0), mod(f0,f2), mod(f2,f0) from Table1", - Seq(DECIMAL(10, 4), DECIMAL(10, 2), INT, DECIMAL(10, 2)), - s1r(d"3.0000", d"2.00", 3, d"1.00")) - // signs. consistent with Java's % operator. - checkQuery1( + checkQuery( Seq(DECIMAL(1, 0), DECIMAL(1, 0)), s1r(d"3", d"5"), - "select mod(f0,f1), mod(-f0,f1), mod(f0,-f1), mod(-f0,-f1) from Table1", + table => table.select('f0 % 'f1, (-'f0) % 'f1,'f0 % (-'f1), (-'f0) % (-'f1)), Seq(DECIMAL(1, 0), DECIMAL(1, 0), DECIMAL(1, 0), DECIMAL(1, 0)), - s1r(3%5, (-3)%5, 3%(-5), (-3)%(-5))) + s1r(3 % 5, (-3) % 5, 3 % (-5), (-3) % (-5))) // rounding in case s1>s2. note that SQL2003 requires s1=s2=0. // (In T-SQL, s2 is expanded to s1, so that there's no rounding.) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 4), DECIMAL(10, 2)), s1r(d"3.1234", d"5"), - "select mod(f0,f1) from Table1", + table => table.select('f0 % 'f1), Seq(DECIMAL(10, 2)), s1r(d"3.12")) } - @Test - def testDiv(): Unit = { - - // see DivCallGen - checkQuery1( - Seq(DECIMAL(7, 0), INT), - s1r(d"7", 2), - "select div(f0, f1), div(100*f1, f0) from Table1", - Seq(DECIMAL(7, 0), DECIMAL(10, 0)), - s1r(3, 200 / 7)) - - checkQuery1( - Seq(DECIMAL(10, 1), DECIMAL(10, 3)), - s1r(d"7.9", d"2.009"), - "select div(f0, f1), div(100*f1, f0) from Table1", - Seq(DECIMAL(12, 0), DECIMAL(18, 0)), - s1r(3, 2009 / 79)) - } - @Test // functions that treat Decimal as exact value def testExactFunctions(): Unit = { - checkQuery1( - Seq(DECIMAL(10, 2), DECIMAL(10, 2)), - s1r(d"3.14", d"2.17"), - "select if(f0>f1, f0, f1) from Table1", - Seq(DECIMAL(10, 2)), - s1r(d"3.14")) - - checkQuery1( - Seq(DECIMAL(10, 2)), - s1r(d"3.14"), - "select abs(f0), abs(-f0) from Table1", - Seq(DECIMAL(10, 2), DECIMAL(10, 2)), - s1r(d"3.14", d"3.14")) - - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2)), s1r(d"3.14"), - "select floor(f0), ceil(f0) from Table1", + table => table.select('f0.floor, 'f0.ceil), Seq(DECIMAL(10, 0), DECIMAL(10, 0)), s1r(d"3", d"4")) // calcite: SIGN(Decimal(p,s))=>Decimal(p,s) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2)), s1r(d"3.14"), - "select sign(f0), sign(-f0), sign(f0-f0) from Table1", + table => table.select('f0.sign, (-'f0).sign, ('f0 - 'f0).sign ), Seq(DECIMAL(10, 2), DECIMAL(10, 2), DECIMAL(11, 2)), s1r(d"1.00", d"-1.00", d"0.00")) // ROUND(Decimal(p,s)[,INT]) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 3)), s1r(d"646.646"), - "select round(f0), round(f0, 0) from Table1", + table => table.select('f0.round(0), 'f0.round(0)), Seq(DECIMAL(8, 0), DECIMAL(8, 0)), s1r(d"647", d"647")) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 3)), s1r(d"646.646"), - "select round(f0,1), round(f0,2), round(f0,3), round(f0,4) from Table1", + table => table.select('f0.round(1), 'f0.round(2), 'f0.round(3), 'f0.round(4) ), Seq(DECIMAL(9, 1), DECIMAL(10, 2), DECIMAL(10, 3), DECIMAL(10, 3)), s1r(d"646.6", d"646.65", d"646.646", d"646.646")) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 3)), s1r(d"646.646"), - "select round(f0,-1), round(f0,-2), round(f0,-3), round(f0,-4) from Table1", + table => table.select('f0.round(-1), 'f0.round(-2), 'f0.round(-3), 'f0.round(-4) ), Seq(DECIMAL(8, 0), DECIMAL(8, 0), DECIMAL(8, 0), DECIMAL(8, 0)), s1r(d"650", d"600", d"1000", d"0")) - checkQuery1( + checkQuery( Seq(DECIMAL(4, 2)), s1r(d"99.99"), - "select round(f0,1), round(-f0,1), round(f0,-1), round(-f0,-1) from Table1", + table => table.select('f0.round(1), (-'f0).round(1), 'f0.round(-1), (-'f0).round(-1) ), Seq(DECIMAL(4, 1), DECIMAL(4, 1), DECIMAL(3, 0), DECIMAL(3, 0)), s1r(d"100.0", d"-100.0", d"100", d"-100")) - checkQuery1( + checkQuery( Seq(DECIMAL(38, 0)), s1r(d"1E38".subtract(d"1")), - "select round(f0,-1) from Table1", + table => table.select('f0.round(-1) ), Seq(DECIMAL(38, 0)), s1r(null)) } @@ -503,52 +415,24 @@ class DecimalITCase extends BatchTestBase { import java.lang.Math._ - checkQuery1( - Seq(DECIMAL(10, 2)), - s1r(d"3.14"), - "select log10(f0), ln(f0), log(f0), log2(f0) from Table1", - Seq(DOUBLE, DOUBLE, DOUBLE, DOUBLE), - s1r(log10(3.14), Math.log(3.14), Math.log(3.14), Math.log(3.14)/Math.log(2.0))) - - checkQuery1( - Seq(DECIMAL(10, 2), DOUBLE), - s1r(d"3.14", 3.14), - "select log(f0,f0), log(f0,f1), log(f1,f0) from Table1", - Seq(DOUBLE, DOUBLE, DOUBLE), - s1r(1.0, 1.0, 1.0)) - - checkQuery1( - Seq(DECIMAL(10, 2), DOUBLE), - s1r(d"3.14", 0.3), - "select power(f0,f0), power(f0,f1), power(f1,f0), sqrt(f0) from Table1", - Seq(DOUBLE, DOUBLE, DOUBLE, DOUBLE), - s1r(pow(3.14, 3.14), pow(3.14, 0.3), pow(0.3, 3.14), pow(3.14, 0.5))) - - checkQuery1( - Seq(DECIMAL(10, 2), DOUBLE), - s1r(d"3.14", 0.3), - "select exp(f0), exp(f1) from Table1", - Seq(DOUBLE, DOUBLE), - s1r(exp(3.14), exp(0.3))) - - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2)), s1r(d"0.12"), - "select sin(f0), cos(f0), tan(f0), cot(f0) from Table1", + table => table.select('f0.sin, 'f0.cos, 'f0.tan, 'f0.cot ), Seq(DOUBLE, DOUBLE, DOUBLE, DOUBLE), - s1r(sin(0.12), cos(0.12), tan(0.12), 1.0/tan(0.12))) + s1r(sin(0.12), cos(0.12), tan(0.12), 1.0 / tan(0.12))) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2)), s1r(d"0.12"), - "select asin(f0), acos(f0), atan(f0) from Table1", + table => table.select('f0.asin, 'f0.acos, 'f0.atan ), Seq(DOUBLE, DOUBLE, DOUBLE), s1r(asin(0.12), acos(0.12), atan(0.12))) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 2)), s1r(d"0.12"), - "select degrees(f0), radians(f0) from Table1", + table => table.select('f0.degrees, 'f0.radians), Seq(DOUBLE, DOUBLE), s1r(toDegrees(0.12), toRadians(0.12))) } @@ -557,17 +441,17 @@ class DecimalITCase extends BatchTestBase { def testAggSum(): Unit = { // SUM(Decimal(p,s))=>Decimal(38,s) - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3)), (0 until 100).map(_ => row(d"1.000")), - "select sum(f0) from Table1", + table => table.select('f0.sum ), Seq(DECIMAL(38, 3)), s1r(d"100.000")) - checkQuery1( + checkQuery( Seq(DECIMAL(37, 0)), (0 until 100).map(_ => row(d"1e36")), - "select sum(f0) from Table1", + table => table.select('f0.sum ), Seq(DECIMAL(38, 0)), s1r(null)) } @@ -576,107 +460,87 @@ class DecimalITCase extends BatchTestBase { def testAggAvg(): Unit = { // AVG(Decimal(p,s)) => Decimal(38,s)/Decimal(20,0) => Decimal(38, max(s,6)) - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3), DECIMAL(20, 10)), (0 until 100).map(_ => row(d"100.000", d"1${10}")), - "select avg(f0), avg(f1) from Table1", + table => table.select('f0.avg, 'f1.avg ), Seq(DECIMAL(38, 6), DECIMAL(38, 10)), s1r(d"100.000000", d"1${10}")) - checkQuery1( + checkQuery( Seq(DECIMAL(37, 0)), (0 until 100).map(_ => row(d"1e36")), - "select avg(f0) from Table1", + table => table.select('f0.avg), Seq(DECIMAL(38, 6)), s1r(null)) } - @Ignore @Test def testAggMinMaxCount(): Unit = { // MIN/MAX(T) => T - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3)), (10 to 90).map(i => row(java.math.BigDecimal.valueOf(i))), - "select min(f0), max(f0), count(f0) from Table1", + table => table.select('f0.min, 'f0.max, 'f0.count ), Seq(DECIMAL(6, 3), DECIMAL(6, 3), LONG), s1r(d"10.000", d"90.000", 81L)) } @Test - def testCaseWhen(): Unit = { - - // result type: SQL2003 $9.23, calcite RelDataTypeFactory.leastRestrictive() - checkQuery1( - Seq(DECIMAL(8, 4), DECIMAL(10, 2)), - s1r(d"0.0001", d"0.01"), - "select case f0 when 0 then f0 else f1 end from Table1", - Seq(DECIMAL(12, 4)), - s1r(d"0.0100")) - - checkQuery1( - Seq(DECIMAL(8, 4), INT), - s1r(d"0.0001", 1), - "select case f0 when 0 then f0 else f1 end from Table1", - Seq(DECIMAL(14, 4)), - s1r(d"1.0000")) - - checkQuery1( - Seq(DECIMAL(8, 4), DOUBLE), - s1r(d"0.0001", 3.14), - "select case f0 when 0 then f1 else f0 end from Table1", - Seq(DOUBLE), - s1r(d"0.0001".doubleValue())) - } - - @Test def testCast(): Unit = { // String, numeric/Decimal => Decimal - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), INT, DOUBLE, STRING), s1r(d"3.14", 3, 3.14, "3.14"), - "select cast(f0 as Decimal(8,4)), cast(f1 as Decimal(8,4)), " + - "cast(f2 as Decimal(8,4)), cast(f3 as Decimal(8,4)) from Table1", + table => table.select('f0.cast(DataTypes.DECIMAL(8,4)), + 'f1.cast(DataTypes.DECIMAL(8,4)), + 'f2.cast(DataTypes.DECIMAL(8,4)), + 'f3.cast(DataTypes.DECIMAL(8,4)) ), Seq(DECIMAL(8, 4), DECIMAL(8, 4), DECIMAL(8, 4), DECIMAL(8, 4)), s1r(d"3.1400", d"3.0000", d"3.1400", d"3.1400")) // round up - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DOUBLE, STRING), s1r(d"3.15", 3.15, "3.15"), - "select cast(f0 as Decimal(8,1)), cast(f1 as Decimal(8,1)), " + - "cast(f2 as Decimal(8,1)) from Table1", + table => table.select( + 'f0.cast(DataTypes.DECIMAL(8,1)), + 'f1.cast(DataTypes.DECIMAL(8,1)), + 'f2.cast(DataTypes.DECIMAL(8,1))), Seq(DECIMAL(8, 1), DECIMAL(8, 1), DECIMAL(8, 1)), s1r(d"3.2", d"3.2", d"3.2")) - checkQuery1( + checkQuery( Seq(DECIMAL(4, 2)), s1r(d"13.14"), - "select cast(f0 as Decimal(3,2)) from Table1", + table => table.select('f0.cast(DataTypes.DECIMAL(3,2)) ), Seq(DECIMAL(3, 2)), s1r(null)) - checkQuery1( + checkQuery( Seq(STRING), s1r("13.14"), - "select cast(f0 as Decimal(3,2)) from Table1", + table => table.select('f0.cast(DataTypes.DECIMAL(3,2)) ), Seq(DECIMAL(3, 2)), s1r(null)) // Decimal => String, numeric - checkQuery1( + checkQuery( Seq(DECIMAL(4, 2)), s1r(d"1.99"), - "select cast(f0 as VARCHAR(64)), cast(f0 as DOUBLE), cast(f0 as INT) from Table1", + table => table.select( + 'f0.cast(DataTypes.VARCHAR(64)), + 'f0.cast(DataTypes.DOUBLE), + 'f0.cast(DataTypes.INT)), Seq(STRING, DOUBLE, INT), s1r("1.99", 1.99, 1)) - checkQuery1( + checkQuery( Seq(DECIMAL(10, 0), DECIMAL(10, 0)), s1r(d"-2147483648", d"2147483647"), - "select cast(f0 as INT), cast(f1 as INT) from Table1", + table => table.select('f0.cast(DataTypes.INT), 'f1.cast(DataTypes.INT)), Seq(INT, INT), s1r(-2147483648, 2147483647)) } @@ -687,46 +551,31 @@ class DecimalITCase extends BatchTestBase { // expressions that test equality. // =, CASE, NULLIF, IN, IS DISTINCT FROM - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select f0=f1, f0=f2, f0=f3, f1=f0, f2=f0, f3=f0 from Table1", + table => table.select('f0==='f1, 'f0==='f2, 'f0==='f3, 'f1==='f0, 'f2==='f0, 'f3==='f0 ), Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL), s1r(true, true, true, true, true, true)) - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select f0 IN(f1), f0 IN(f2), f0 IN(f3), " + - "f1 IN(f0), f2 IN(f0), f3 IN(f0) from Table1", + table => table.select('f0.in('f1), 'f0.in('f2), 'f0.in('f3), + 'f1.in('f0), 'f2.in('f0), 'f3.in('f0) ), Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL), s1r(true, true, true, true, true, true)) - checkQuery1( - Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), - s1r(d"1", d"1", 1, 1.0), - "select " + - "f0 IS DISTINCT FROM f1, f1 IS DISTINCT FROM f0, " + - "f0 IS DISTINCT FROM f2, f2 IS DISTINCT FROM f0, " + - "f0 IS DISTINCT FROM f3, f3 IS DISTINCT FROM f0 from Table1", - Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL), - s1r(false, false, false, false, false, false)) - - checkQuery1( - Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), - s1r(d"1", d"1", 1, 1.0), - "select NULLIF(f0,f1), NULLIF(f0,f2), NULLIF(f0,f3)," + - "NULLIF(f1,f0), NULLIF(f2,f0), NULLIF(f3,f0) from Table1", - Seq(DECIMAL(8, 2), DECIMAL(8, 2), DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), - s1r(null, null, null, null, null, null)) - - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select " + - "case f0 when f1 then 1 else 0 end, case f1 when f0 then 1 else 0 end, " + - "case f0 when f2 then 1 else 0 end, case f2 when f0 then 1 else 0 end, " + - "case f0 when f3 then 1 else 0 end, case f3 when f0 then 1 else 0 end from Table1", + table => table.select( + ('f0 === 'f1) ? (1, 0), + ('f1 === 'f0) ?(1, 0), + ('f0 === 'f2) ? (1, 0), + ('f2 === 'f0) ? (1, 0), + ('f0 === 'f3) ? (1, 0), + ('f3 === 'f0) ? (1, 0)), Seq(INT, INT, INT, INT, INT, INT), s1r(1, 1, 1, 1, 1, 1)) } @@ -734,39 +583,45 @@ class DecimalITCase extends BatchTestBase { @Test def testComparison(): Unit = { - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select f0<f1, f0<f2, f0<f3, f1<f0, f2<f0, f3<f0 from Table1", + table => table.select('f0<'f1, 'f0<'f2, 'f0<'f3, 'f1<'f0, 'f2<'f0, 'f3<'f0 ), Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL), s1r(false, false, false, false, false, false)) // no overflow during type conversion. // conceptually both operands are promoted to infinite precision before comparison. - checkQuery1( + checkQuery( Seq(DECIMAL(1, 0), DECIMAL(2, 0), INT, DOUBLE), s1r(d"1", d"99", 99, 99.0), - "select f0<f1, f0<f2, f0<f3, f1<f0, f2<f0, f3<f0 from Table1", + table => table.select('f0<'f1, 'f0<'f2, 'f0<'f3, 'f1<'f0, 'f2<'f0, 'f3<'f0 ), Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL), s1r(true, true, true, false, false, false)) - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select " + - "f0 between f1 and 1, f1 between f0 and 1, " + - "f0 between f2 and 1, f2 between f0 and 1, " + - "f0 between f3 and 1, f3 between f0 and 1 from Table1", + table => table.select( + 'f0.between('f1, 1), + 'f1.between('f0, 1), + 'f0.between('f2, 1), + 'f2.between('f0, 1), + 'f0.between('f3, 1), + 'f3.between('f0, 1)), Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL), s1r(true, true, true, true, true, true)) - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select " + - "f0 between 0 and f1, f1 between 0 and f0, " + - "f0 between 0 and f2, f2 between 0 and f0, " + - "f0 between 0 and f3, f3 between 0 and f0 from Table1", + table => table.select( + 'f0.between(0, 'f1), + 'f1.between(0, 'f0), + 'f0.between(0, 'f2), + 'f2.between(0, 'f0), + 'f0.between(0, 'f3), + 'f3.between(0, 'f0)), Seq(BOOL, BOOL, BOOL, BOOL, BOOL, BOOL), s1r(true, true, true, true, true, true)) } @@ -776,10 +631,10 @@ class DecimalITCase extends BatchTestBase { tEnv.getConfig.getConfiguration.setString( ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin") - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select count(*) from Table1 A, Table1 B where A.f0=B.f0", + table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f0).select(1.count), Seq(LONG), s1r(1L)) } @@ -789,10 +644,10 @@ class DecimalITCase extends BatchTestBase { tEnv.getConfig.getConfiguration.setString( ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin") - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select count(*) from Table1 A, Table1 B where A.f0=B.f1", + table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f1).select(1.count), Seq(LONG), s1r(1L)) } @@ -802,10 +657,10 @@ class DecimalITCase extends BatchTestBase { tEnv.getConfig.getConfiguration.setString( ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin") - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select count(*) from Table1 A, Table1 B where A.f1=B.f0", + table => table.as('a, 'b, 'c, 'd).join(table).where('b === 'f0).select(1.count), Seq(LONG), s1r(1L)) @@ -816,10 +671,10 @@ class DecimalITCase extends BatchTestBase { tEnv.getConfig.getConfiguration.setString( ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin") - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select count(*) from Table1 A, Table1 B where A.f0=B.f2", + table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f2).select(1.count), Seq(LONG), s1r(1L)) } @@ -829,10 +684,10 @@ class DecimalITCase extends BatchTestBase { tEnv.getConfig.getConfiguration.setString( ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin") - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select count(*) from Table1 A, Table1 B where A.f2=B.f0", + table => table.as('a, 'b, 'c, 'd).join(table).where('c === 'f0).select(1.count), Seq(LONG), s1r(1L)) } @@ -842,10 +697,10 @@ class DecimalITCase extends BatchTestBase { tEnv.getConfig.getConfiguration.setString( ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin") - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select count(*) from Table1 A, Table1 B where A.f0=B.f3", + table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f3).select(1.count), Seq(LONG), s1r(1L)) } @@ -854,21 +709,20 @@ class DecimalITCase extends BatchTestBase { def testJoin7(): Unit = { tEnv.getConfig.getConfiguration.setString( ExecutionConfigOptions.SQL_EXEC_DISABLED_OPERATORS, "HashJoin, NestedLoopJoin") - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2), DECIMAL(8, 4), INT, DOUBLE), s1r(d"1", d"1", 1, 1.0), - "select count(*) from Table1 A, Table1 B where A.f3=B.f0", + table => table.as('a, 'b, 'c, 'd).join(table).where('a === 'f3).select(1.count), Seq(LONG), s1r(1L)) } - @Ignore @Test def testGroupBy(): Unit = { - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2)), Seq(row(d"1"), row(d"3"), row(d"1.0"), row(d"2")), - "select count(*) from Table1 A group by f0", + table => table.groupBy('f0).select(1.count), Seq(LONG), Seq(row(2L), row(1L), row(1L))) } @@ -876,10 +730,10 @@ class DecimalITCase extends BatchTestBase { @Test def testOrderBy(): Unit = { env.setParallelism(1) // set sink parallelism to 1 - checkQuery1( + checkQuery( Seq(DECIMAL(8, 2)), Seq(row(d"1"), row(d"3"), row(d"1.0"), row(d"2")), - "select f0 from Table1 A order by f0", + table => table.select('f0).orderBy('f0), Seq(DECIMAL(8, 2)), Seq(row(d"1.00"), row(d"1.00"), row(d"2.00"), row(d"3.00")), isSorted = true) @@ -887,31 +741,29 @@ class DecimalITCase extends BatchTestBase { @Test def testSimpleNull(): Unit = { - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), Seq(row(d"100.000", null, null)), - "select distinct(f0), f1, f2 from (select t1.f0, t1.f1, t1.f2 from Table1 t1 " + - "union all (SELECT * FROM Table1)) order by f0", + table => table.union(table).select('f0, 'f1, 'f2).orderBy('f0), Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), s1r(d"100.000", null, null)) } - @Ignore @Test def testAggAvgGroupBy(): Unit = { // null - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), (0 until 100).map(_ => row(d"100.000", null, null)), - "select f0, avg(f1), avg(f2) from Table1 group by f0", + table => table.groupBy('f0).select('f0, 'f1.avg, 'f2.avg), Seq(DECIMAL(6, 3), DECIMAL(38, 6), DECIMAL(38, 10)), s1r(d"100.000", null, null)) - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), (0 until 100).map(_ => row(d"100.000", d"100.000", d"1${10}")), - "select f0, avg(f1), avg(f2) from Table1 group by f0", + table => table.groupBy('f0).select('f0, 'f1.avg, 'f2.avg), Seq(DECIMAL(6, 3), DECIMAL(38, 6), DECIMAL(38, 10)), s1r(d"100.000", d"100.000000", d"1${10}")) } @@ -920,17 +772,17 @@ class DecimalITCase extends BatchTestBase { def testAggMinGroupBy(): Unit = { // null - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), (0 until 100).map(_ => row(d"100.000", null, null)), - "select f0, min(f1), min(f2) from Table1 group by f0", + table => table.groupBy('f0).select('f0, 'f1.min, 'f2.min), Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), s1r(d"100.000", null, null)) - checkQuery1( + checkQuery( Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), (0 until 100).map(i => row(d"100.000", new JBigDecimal(100 - i), d"1${10}")), - "select f0, min(f1), min(f2) from Table1 group by f0", + table => table.groupBy('f0).select('f0, 'f1.min, 'f2.min), Seq(DECIMAL(6, 3), DECIMAL(6, 3), DECIMAL(20, 10)), s1r(d"100.000", d"1.000", d"1${10}")) }