This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new f5900a50c9e3 [SPARK-43393][SQL][3.4] Address sequence expression overflow bug
f5900a50c9e3 is described below
commit f5900a50c9e3223ec9c2a48104a90e96e328a0ec
Author: Deepayan Patra <[email protected]>
AuthorDate: Fri Nov 17 15:35:43 2023 -0800
[SPARK-43393][SQL][3.4] Address sequence expression overflow bug
### What changes were proposed in this pull request?
Spark has a (long-standing) overflow bug in the `sequence` expression.
Consider the following operations:
```
spark.sql("CREATE TABLE foo (l LONG);")
spark.sql(s"INSERT INTO foo VALUES (${Long.MaxValue});")
spark.sql("SELECT sequence(0, l) FROM foo;").collect()
```
The result of these operations will be:
```
Array[org.apache.spark.sql.Row] = Array([WrappedArray()])
```
an unintended consequence of overflow.
The sequence is applied to values `0` and `Long.MaxValue` with a step size
of `1`, which uses the length computation defined
[here](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3451).
In this calculation, with `start = 0`, `stop = Long.MaxValue`, and `step = 1`,
the calculated `len` overflows to `Long.MinValue`. In binary, the computation
looks like:
```
0111111111111111111111111111111111111111111111111111111111111111
- 0000000000000000000000000000000000000000000000000000000000000000
------------------------------------------------------------------
0111111111111111111111111111111111111111111111111111111111111111
/ 0000000000000000000000000000000000000000000000000000000000000001
------------------------------------------------------------------
0111111111111111111111111111111111111111111111111111111111111111
+ 0000000000000000000000000000000000000000000000000000000000000001
------------------------------------------------------------------
1000000000000000000000000000000000000000000000000000000000000000
```
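The same overflow can be reproduced directly in plain Scala (a minimal sketch; the variable names are illustrative and not taken from the patch):
```
val start = 0L
val stop = Long.MaxValue
val step = 1L

// (stop - start) / step is Long.MaxValue; adding 1 silently wraps around.
val len = 1L + (stop - start) / step
// len == Long.MinValue, i.e. -9223372036854775808
```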
The following
[check](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3454)
passes, since the negative `Long.MinValue` is still `<= MAX_ROUNDED_ARRAY_LENGTH`.
The subsequent `toInt` cast uses this representation and [truncates the upper
bits](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spa
[...]
Other overflows are similarly problematic.
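To make the failure mode concrete, this is roughly what the unpatched path does with the overflowed value (a sketch; `MAX_ROUNDED_ARRAY_LENGTH` is the existing constant from `ByteArrayMethods`):
```
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

val len = Long.MinValue                      // the overflowed length from above
val passes = len <= MAX_ROUNDED_ARRAY_LENGTH // true: the "too long sequence" guard does not fire
val intLen = len.toInt                       // 0: the upper 32 bits are truncated away
// A zero-length array is then allocated, which is why the query returns WrappedArray().
```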
This PR addresses the issue by checking numeric operations in the length
computation for overflow.
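In outline, the fix replaces the unchecked arithmetic with exact (throwing) operations and falls back to `BigInt` to report the true length when an overflow is caught. A sketch of the idea, not the patch code itself:
```
import scala.util.Try

// Math.*Exact throws instead of silently wrapping:
val delta = Math.subtractExact(Long.MaxValue, 0L)   // fine: Long.MaxValue
val wrapped = Try(Math.addExact(1L, delta))         // Failure(ArithmeticException("long overflow"))

// When that happens, the length is recomputed with BigInt, which cannot wrap,
// so the "too long sequence" error can report the real element count:
val safeLen = BigInt(1) + (BigInt(Long.MaxValue) - BigInt(0L)) / BigInt(1)  // 9223372036854775808
```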
### Why are the changes needed?
There is a correctness bug from overflow in the `sequence` expression.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Tests added in `CollectionExpressionsSuite.scala`.
Closes #43819 from thepinetree/spark-sequence-overflow-3.4.
Authored-by: Deepayan Patra <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../expressions/collectionOperations.scala | 47 +++++++++++++++-------
.../expressions/CollectionExpressionsSuite.scala | 44 +++++++++++++++++---
2 files changed, 71 insertions(+), 20 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 629ae0499b4d..f7e03649d5d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -22,6 +22,7 @@ import java.util.Comparator
import scala.collection.mutable
import scala.reflect.ClassTag
+import org.apache.spark.SparkException.internalError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion,
UnresolvedAttribute, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -39,7 +40,6 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
/**
@@ -3011,6 +3011,34 @@ case class Sequence(
}
object Sequence {
+ private def prettyName: String = "sequence"
+
+ def sequenceLength(start: Long, stop: Long, step: Long): Int = {
+ try {
+ val delta = Math.subtractExact(stop, start)
+ if (delta == Long.MinValue && step == -1L) {
+ // We must special-case division of Long.MinValue by -1 to catch potential unchecked
+ // overflow in next operation. Division does not have a builtin overflow check. We
+ // previously special-case div-by-zero.
+ throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
+ }
+ val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
+ if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(len)
+ }
+ len.toInt
+ } catch {
+ // We handle overflows in the previous try block by raising an appropriate exception.
+ case _: ArithmeticException =>
+ val safeLen =
+ BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
+ if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+ throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(safeLen)
+ }
+ throw internalError("Unreachable code reached.")
+ case e: Exception => throw e
+ }
+ }
private type LessThanOrEqualFn = (Any, Any) => Boolean
@@ -3382,13 +3410,7 @@ object Sequence {
|| (estimatedStep == num.zero && start == stop),
s"Illegal sequence boundaries: $start to $stop by $step")
- val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
-
- require(
- len <= MAX_ROUNDED_ARRAY_LENGTH,
- s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
- len.toInt
+ sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
}
private def genSequenceLengthCode(
@@ -3398,7 +3420,7 @@ object Sequence {
step: String,
estimatedStep: String,
len: String): String = {
- val longLen = ctx.freshName("longLen")
+ val calcFn = classOf[Sequence].getName + ".sequenceLength"
s"""
|if (!(($estimatedStep > 0 && $start <= $stop) ||
| ($estimatedStep < 0 && $start >= $stop) ||
@@ -3406,12 +3428,7 @@ object Sequence {
| throw new IllegalArgumentException(
| "Illegal sequence boundaries: " + $start + " to " + $stop + " by "
+ $step);
|}
- |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
- |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
- | throw new IllegalArgumentException(
- | "Too long sequence: " + $longLen + ". Should be <=
$MAX_ROUNDED_ARRAY_LENGTH");
- |}
- |int $len = (int) $longLen;
+ |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
""".stripMargin
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 1787f6ac72dd..99eece31a1ef 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds,
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.UTF8String
class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -769,10 +769,6 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
// test sequence boundaries checking
- checkExceptionInExpression[IllegalArgumentException](
- new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
- EmptyRow, s"Too long sequence: 4294967296. Should be <=
$MAX_ROUNDED_ARRAY_LENGTH")
-
checkExceptionInExpression[IllegalArgumentException](
new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
checkExceptionInExpression[IllegalArgumentException](
@@ -782,6 +778,44 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkExceptionInExpression[IllegalArgumentException](
new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")
+ // SPARK-43393: test Sequence overflow checking
+ checkErrorInExpression[SparkRuntimeException](
+ new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+ errorClass = "_LEGACY_ERROR_TEMP_2161",
+ parameters = Map(
+ "count" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString,
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ checkErrorInExpression[SparkRuntimeException](
+ new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
+ errorClass = "_LEGACY_ERROR_TEMP_2161",
+ parameters = Map(
+ "count" -> (BigInt(Long.MaxValue) + 1).toString,
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ checkErrorInExpression[SparkRuntimeException](
+ new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
+ errorClass = "_LEGACY_ERROR_TEMP_2161",
+ parameters = Map(
+ "count" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ checkErrorInExpression[SparkRuntimeException](
+ new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
+ errorClass = "_LEGACY_ERROR_TEMP_2161",
+ parameters = Map(
+ "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ checkErrorInExpression[SparkRuntimeException](
+ new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
+ errorClass = "_LEGACY_ERROR_TEMP_2161",
+ parameters = Map(
+ "count" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+ checkErrorInExpression[SparkRuntimeException](
+ new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
+ errorClass = "_LEGACY_ERROR_TEMP_2161",
+ parameters = Map(
+ "count" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString,
+ "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString()))
+
// test sequence with one element (zero step or equal start and stop)
checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]