This is an automated email from the ASF dual-hosted git repository. yamamuro pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new 72c466e [SPARK-31761][SQL][3.0] cast integer to Long to avoid IntegerOverflow for IntegralDivide operator 72c466e is described below commit 72c466e0c37e4cc639040161699b6c0bffde70d5 Author: sandeep katta <sandeep.katta2...@gmail.com> AuthorDate: Sun May 24 21:39:16 2020 +0900 [SPARK-31761][SQL][3.0] cast integer to Long to avoid IntegerOverflow for IntegralDivide operator ### What changes were proposed in this pull request? The `IntegralDivide` operator returns the Long DataType, so the integer overflow case should be handled. If the operands are of type Int, they will be cast to Long ### Why are the changes needed? As `IntegralDivide` returns the Long datatype, integer overflow should not happen ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a UT and also tested in the local cluster After fix  SQL Test After fix  Before Fix  Closes #28628 from sandeep-katta/branch3Backport. Authored-by: sandeep katta <sandeep.katta2...@gmail.com> Signed-off-by: Takeshi Yamamuro <yamam...@apache.org> --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 18 ++++++++++++++++ .../sql/catalyst/expressions/arithmetic.scala | 2 +- .../sql/catalyst/analysis/TypeCoercionSuite.scala | 24 ++++++++++++++++++++++ .../expressions/ArithmeticExpressionSuite.scala | 7 +------ .../sql-functions/sql-expression-schema.md | 2 +- .../resources/sql-tests/results/operators.sql.out | 8 ++++---- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++++++ 7 files changed, 57 insertions(+), 12 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index c6e3f56..a6f8e12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala 
@@ -61,6 +61,7 @@ object TypeCoercion { IfCoercion :: StackCoercion :: Division :: + IntegralDivision :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -685,6 +686,23 @@ object TypeCoercion { } /** + * The DIV operator always returns long-type value. + * This rule cast the integral inputs to long type, to avoid overflow during calculation. + */ + object IntegralDivision extends TypeCoercionRule { + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { + case e if !e.childrenResolved => e + case d @ IntegralDivide(left, right) => + IntegralDivide(mayCastToLong(left), mayCastToLong(right)) + } + + private def mayCastToLong(expr: Expression): Expression = expr.dataType match { + case _: ByteType | _: ShortType | _: IntegerType => Cast(expr, LongType) + case _ => expr + } + } + + /** * Coerces the type of different branches of a CASE WHEN statement to a common type. */ object CaseWhenCoercion extends TypeCoercionRule { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 354845d..7c52183 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -412,7 +412,7 @@ case class IntegralDivide( left: Expression, right: Expression) extends DivModLike { - override def inputType: AbstractDataType = TypeCollection(IntegralType, DecimalType) + override def inputType: AbstractDataType = TypeCollection(LongType, DecimalType) override def dataType: DataType = LongType diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index e37555f..1ea1ddb 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1559,6 +1559,30 @@ class TypeCoercionSuite extends AnalysisTest { Literal.create(null, DecimalType.SYSTEM_DEFAULT))) } } + + test("SPARK-31761: byte, short and int should be cast to long for IntegralDivide's datatype") { + val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + // Casts Byte to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toByte, 1.toByte), + IntegralDivide(Cast(2.toByte, LongType), Cast(1.toByte, LongType))) + // Casts Short to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1.toShort), + IntegralDivide(Cast(2.toShort, LongType), Cast(1.toShort, LongType))) + // Casts Integer to Long + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1), + IntegralDivide(Cast(2, LongType), Cast(1, LongType))) + // should not be any change for Long data types + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1L), IntegralDivide(2L, 1L)) + // one of the operand is byte + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2L, 1.toByte), + IntegralDivide(2L, Cast(1.toByte, LongType))) + // one of the operand is short + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2.toShort, 1L), + IntegralDivide(Cast(2.toShort, LongType), 1L)) + // one of the operand is int + ruleTest(TypeCoercion.IntegralDivision, IntegralDivide(2, 1L), + IntegralDivide(Cast(2, LongType), 1L)) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 675f85f..f05598a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -173,13 +173,8 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("/ (Divide) for integral type") { - checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L) - checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L) - checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L) + test("/ (Divide) for Long type") { checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L) - checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L) - checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L) checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) } diff --git a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md index c3ae2a7..9e24a54 100644 --- a/sql/core/src/test/resources/sql-functions/sql-expression-schema.md +++ b/sql/core/src/test/resources/sql-functions/sql-expression-schema.md @@ -136,7 +136,7 @@ | org.apache.spark.sql.catalyst.expressions.InputFileBlockLength | input_file_block_length | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.InputFileBlockStart | input_file_block_start | N/A | N/A | | org.apache.spark.sql.catalyst.expressions.InputFileName | input_file_name | N/A | N/A | -| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(3 div 2):bigint> | +| org.apache.spark.sql.catalyst.expressions.IntegralDivide | div | SELECT 3 div 2 | struct<(CAST(3 AS BIGINT) div CAST(2 AS BIGINT)):bigint> | | org.apache.spark.sql.catalyst.expressions.IsNaN | isnan | SELECT isnan(cast('NaN' as double)) | struct<isnan(CAST(NaN AS DOUBLE)):boolean> | | org.apache.spark.sql.catalyst.expressions.IsNotNull | isnotnull | SELECT isnotnull(1) | struct<(1 IS NOT NULL):boolean> | | 
org.apache.spark.sql.catalyst.expressions.IsNull | isnull | SELECT isnull(1) | struct<(1 IS NULL):boolean> | diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index a94a123..9accc57 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -157,7 +157,7 @@ NULL -- !query select 5 div 2 -- !query schema -struct<(5 div 2):bigint> +struct<(CAST(5 AS BIGINT) div CAST(2 AS BIGINT)):bigint> -- !query output 2 @@ -165,7 +165,7 @@ struct<(5 div 2):bigint> -- !query select 5 div 0 -- !query schema -struct<(5 div 0):bigint> +struct<(CAST(5 AS BIGINT) div CAST(0 AS BIGINT)):bigint> -- !query output NULL @@ -173,7 +173,7 @@ NULL -- !query select 5 div null -- !query schema -struct<(5 div CAST(NULL AS INT)):bigint> +struct<(CAST(5 AS BIGINT) div CAST(NULL AS BIGINT)):bigint> -- !query output NULL @@ -181,7 +181,7 @@ NULL -- !query select null div 5 -- !query schema -struct<(CAST(NULL AS INT) div 5):bigint> +struct<(CAST(NULL AS BIGINT) div CAST(5 AS BIGINT)):bigint> -- !query output NULL diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d336f52..a23e583 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -3441,6 +3441,14 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark assert(SQLConf.get.getConf(SQLConf.CODEGEN_FALLBACK) === true) } } + + test("SPARK-31761: test byte, short, integer overflow for (Divide) integral type") { + checkAnswer(sql("Select -2147483648 DIV -1"), Seq(Row(Integer.MIN_VALUE.toLong * -1))) + checkAnswer(sql("select CAST(-128 as Byte) DIV CAST (-1 as Byte)"), + Seq(Row(Byte.MinValue.toLong * -1))) + checkAnswer(sql("select CAST(-32768 as short) 
DIV CAST (-1 as short)"), + Seq(Row(Short.MinValue.toLong * -1))) + } } case class Foo(bar: Option[String]) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org