This is an automated email from the ASF dual-hosted git repository. godfrey pushed a commit to branch release-1.14 in repository https://gitbox.apache.org/repos/asf/flink.git
commit 48fe11dcb81b2aac5c10f69ac4c33e8e67603d9b Author: xuyang <[email protected]> AuthorDate: Sat Oct 9 14:11:12 2021 +0800 [FLINK-15987][table-planner] Fix SELECT 1.0e0 / 0.0e0 throws NumberFormatException This closes #17436 (cherry picked from commit 8c603d8d07984ec136fe5b1778ccb39b039c5d75) --- .../table/planner/codegen/ExpressionReducer.scala | 71 ++++++++++++++++------ .../planner/expressions/ScalarFunctionsTest.scala | 7 +++ .../planner/expressions/SqlExpressionTest.scala | 17 ++++++ .../planner/expressions/TemporalTypesTest.scala | 53 +++++++++------- .../expressions/utils/ExpressionTestBase.scala | 1 + .../validation/ScalarFunctionsValidationTest.scala | 16 ----- 6 files changed, 106 insertions(+), 59 deletions(-) diff --git a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala index 166011f..d657b8e 100644 --- a/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala +++ b/flink-table/flink-table-planner/src/main/scala/org/apache/flink/table/planner/codegen/ExpressionReducer.scala @@ -32,7 +32,8 @@ import org.apache.flink.table.types.logical.RowType import org.apache.flink.table.util.TimestampStringUtils.fromLocalDateTime import org.apache.calcite.avatica.util.ByteString -import org.apache.calcite.rex.{RexBuilder, RexExecutor, RexNode} +import org.apache.calcite.rex.{RexBuilder, RexCall, RexExecutor, RexLiteral, RexNode, RexUtil} +import org.apache.calcite.sql.SqlKind import org.apache.calcite.sql.`type`.SqlTypeName import scala.collection.JavaConverters._ @@ -61,25 +62,7 @@ class ExpressionReducer( val pythonUDFExprs = new ListBuffer[RexNode]() - val literals = constExprs.asScala.map(e => (e.getType.getSqlTypeName, e)).flatMap { - - // Skip expressions that contain python functions because it's quite expensive to - // call Python UDFs 
during optimization phase. They will be optimized during the runtime. - case (_, e) if containsPythonCall(e) => - pythonUDFExprs += e - None - - // we don't support object literals yet, we skip those constant expressions - case (SqlTypeName.ANY, _) | - (SqlTypeName.OTHER, _) | - (SqlTypeName.ROW, _) | - (SqlTypeName.STRUCTURED, _) | - (SqlTypeName.ARRAY, _) | - (SqlTypeName.MAP, _) | - (SqlTypeName.MULTISET, _) => None - - case (_, e) => Some(e) - } + val literals = skipAndValidateExprs(rexBuilder, constExprs, pythonUDFExprs) val literalTypes = literals.map(e => FlinkTypeFactory.toLogicalType(e.getType)) val resultType = RowType.of(literalTypes: _*) @@ -244,6 +227,54 @@ class ExpressionReducer( targetType, true) } + + /** + * skip the expressions that can't be reduced now + * and validate the expressions + */ + private def skipAndValidateExprs( + rexBuilder: RexBuilder, + constExprs: java.util.List[RexNode], + pythonUDFExprs: ListBuffer[RexNode]): List[RexNode] ={ + constExprs.asScala.map(e => (e.getType.getSqlTypeName, e)).flatMap { + + // Skip expressions that contain python functions because it's quite expensive to + // call Python UDFs during optimization phase. They will be optimized during the runtime. 
+ case (_, e) if containsPythonCall(e) => + pythonUDFExprs += e + None + + // we don't support object literals yet, we skip those constant expressions + case (SqlTypeName.ANY, _) | + (SqlTypeName.OTHER, _) | + (SqlTypeName.ROW, _) | + (SqlTypeName.STRUCTURED, _) | + (SqlTypeName.ARRAY, _) | + (SqlTypeName.MAP, _) | + (SqlTypeName.MULTISET, _) => None + + case (_, call: RexCall) => { + // to ensure the division is non-zero when the operator is DIVIDE + if (call.getOperator.getKind.equals(SqlKind.DIVIDE)) { + val ops = call.getOperands + val divisionLiteral = ops.get(ops.size() - 1) + + // according to BuiltInFunctionDefinitions, DIVIDE's second operand must be numeric; the cast below additionally assumes it is a RexLiteral and would throw ClassCastException otherwise + assert(RexUtil.isDeterministic(divisionLiteral)) + val divisionComparable = + divisionLiteral.asInstanceOf[RexLiteral].getValue.asInstanceOf[Comparable[Any]] + val zeroComparable = rexBuilder.makeExactLiteral( + new java.math.BigDecimal(0)) + .getValue.asInstanceOf[Comparable[Any]] + if (divisionComparable.compareTo(zeroComparable) == 0) { + throw new ArithmeticException("Division by zero") + } + } + Some(call) + } + case (_, e) => Some(e) + }.toList + } } /** diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala index c0bf7f7..5957a66 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/ScalarFunctionsTest.scala @@ -2552,6 +2552,13 @@ class ScalarFunctionsTest extends ScalarTypesTestBase { "LOG(cast (null AS DOUBLE), cast (null AS DOUBLE))", "null" ) + + // invalid log + val infiniteOrNaNException = "Infinite or NaN" + // Infinity + testExpectedSqlException("LOG(1, 100)", infiniteOrNaNException, classOf[NumberFormatException]) + // NaN +
testExpectedSqlException("LOG(-1)", infiniteOrNaNException, classOf[NumberFormatException]) } @Test diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/SqlExpressionTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/SqlExpressionTest.scala index 721ad72..9c9cefa 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/SqlExpressionTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/SqlExpressionTest.scala @@ -144,6 +144,23 @@ class SqlExpressionTest extends ExpressionTestBase { // Decimal(2,1) / Decimal(10,0) => Decimal(23,12) testSqlApi("2.0/(-3)", "-0.666666666667") testSqlApi("-7.9/2", "-3.950000000000") + + // invalid division + val divisorZeroException = "Division by zero" + testExpectedSqlException( + "1/cast(0.00 as decimal)", divisorZeroException, classOf[ArithmeticException]) + testExpectedSqlException( + "1/cast(0.00 as double)", divisorZeroException, classOf[ArithmeticException]) + testExpectedSqlException( + "1/cast(0.00 as float)", divisorZeroException, classOf[ArithmeticException]) + testExpectedSqlException( + "1/cast(0 as tinyint)", divisorZeroException, classOf[ArithmeticException]) + testExpectedSqlException( + "1/cast(0 as smallint)", divisorZeroException, classOf[ArithmeticException]) + testExpectedSqlException( + "1/0", divisorZeroException, classOf[ArithmeticException]) + testExpectedSqlException( + "1/cast(0 as bigint)", divisorZeroException, classOf[ArithmeticException]) } @Test diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/TemporalTypesTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/TemporalTypesTest.scala index f854bfe..70bd5ca 100644 --- 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/TemporalTypesTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/TemporalTypesTest.scala @@ -18,12 +18,6 @@ package org.apache.flink.table.planner.expressions -import java.sql.Timestamp -import java.text.SimpleDateFormat -import java.time.{Instant, ZoneId, ZoneOffset} -import java.util.{Locale, TimeZone} -import java.lang.{Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong} - import org.apache.flink.table.api._ import org.apache.flink.table.expressions.TimeIntervalUnit import org.apache.flink.table.planner.codegen.CodeGenException @@ -32,8 +26,15 @@ import org.apache.flink.table.planner.utils.DateTimeTestUtil import org.apache.flink.table.planner.utils.DateTimeTestUtil._ import org.apache.flink.table.types.DataType import org.apache.flink.types.Row + import org.junit.Test +import java.lang.{Double => JDouble, Float => JFloat, Integer => JInt, Long => JLong} +import java.sql.Timestamp +import java.text.SimpleDateFormat +import java.time.{Instant, ZoneId, ZoneOffset} +import java.util.Locale + class TemporalTypesTest extends ExpressionTestBase { @Test @@ -1116,6 +1117,29 @@ class TemporalTypesTest extends ExpressionTestBase { "1437699600") } + /** + * now Flink only support TIMESTAMP(3) as the return type in TO_TIMESTAMP + * See: https://issues.apache.org/jira/browse/FLINK-14925 + */ + @Test + def testToTimeStampFunctionWithHighPrecision(): Unit = { + testSqlApi( + "TO_TIMESTAMP('1970-01-01 00:00:00.123456789')", + "1970-01-01 00:00:00.123") + + testSqlApi( + "TO_TIMESTAMP('1970-01-01 00:00:00.12345', 'yyyy-MM-dd HH:mm:ss.SSSSS')", + "1970-01-01 00:00:00.123") + + testSqlApi( + "TO_TIMESTAMP('20000202 59:59.1234567', 'yyyyMMdd mm:ss.SSSSSSS')", + "2000-02-02 00:59:59.123") + + testSqlApi( + "TO_TIMESTAMP('1234567', 'SSSSSSS')", + "1970-01-01 00:00:00.123") + } + @Test def testHighPrecisionTimestamp(): Unit 
= { // EXTRACT should support millisecond/microsecond/nanosecond @@ -1167,15 +1191,6 @@ class TemporalTypesTest extends ExpressionTestBase { // "TIMESTAMP '1970-01-01 00:00:00.123455789')", // "1") - // TO_TIMESTAMP should support up to nanosecond - testSqlApi( - "TO_TIMESTAMP('1970-01-01 00:00:00.123456789')", - "1970-01-01 00:00:00.123456789") - - testSqlApi( - "TO_TIMESTAMP('1970-01-01 00:00:00.12345', 'yyyy-MM-dd HH:mm:ss.SSSSS')", - "1970-01-01 00:00:00.12345") - testSqlApi("TO_TIMESTAMP('abc')", "null") // TO_TIMESTAMP should complement YEAR/MONTH/DAY/HOUR/MINUTE/SECOND/NANO_OF_SECOND @@ -1183,14 +1198,6 @@ class TemporalTypesTest extends ExpressionTestBase { "TO_TIMESTAMP('2000020210', 'yyyyMMddHH')", "2000-02-02 10:00:00.000") - testSqlApi( - "TO_TIMESTAMP('20000202 59:59.1234567', 'yyyyMMdd mm:ss.SSSSSSS')", - "2000-02-02 00:59:59.1234567") - - testSqlApi( - "TO_TIMESTAMP('1234567', 'SSSSSSS')", - "1970-01-01 00:00:00.1234567") - // CAST between two TIMESTAMPs testSqlApi( "CAST(TIMESTAMP '1970-01-01 00:00:00.123456789' AS TIMESTAMP(6))", diff --git a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala index 567e963..6504f4f 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/utils/ExpressionTestBase.scala @@ -303,6 +303,7 @@ abstract class ExpressionTestBase { exceptionClass: Class[_ <: Throwable], exprs: mutable.ArrayBuffer[_]): Unit = { val builder = new HepProgramBuilder() + builder.addRuleInstance(CoreRules.PROJECT_REDUCE_EXPRESSIONS) builder.addRuleInstance(CoreRules.PROJECT_TO_CALC) val hep = new HepPlanner(builder.build()) hep.setRoot(relNode) diff --git 
a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala index d1760dc..43d3f60 100644 --- a/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala +++ b/flink-table/flink-table-planner/src/test/scala/org/apache/flink/table/planner/expressions/validation/ScalarFunctionsValidationTest.scala @@ -32,22 +32,6 @@ class ScalarFunctionsValidationTest extends ScalarTypesTestBase { // Math functions // ---------------------------------------------------------------------------------------------- - @Test - def testInvalidLog1(): Unit = { - testSqlApi( - "LOG(1, 100)", - "Infinity" - ) - } - - @Test - def testInvalidLog2(): Unit ={ - testSqlApi( - "LOG(-1)", - "NaN" - ) - } - @Test(expected = classOf[ValidationException]) def testInvalidBin1(): Unit = { testSqlApi("BIN(f12)", "101010") // float type
