This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new afc4c49927cb [SPARK-43393][SQL] Address sequence expression overflow bug
afc4c49927cb is described below

commit afc4c49927cb7f0f2a7f24a42c4fe497796dd9e3
Author: Deepayan Patra <deepayan.pa...@databricks.com>
AuthorDate: Wed Nov 15 14:27:34 2023 +0800

    [SPARK-43393][SQL] Address sequence expression overflow bug
    
    ### What changes were proposed in this pull request?
    Spark has a (long-standing) overflow bug in the `sequence` expression.
    
    Consider the following operations:
    ```
    spark.sql("CREATE TABLE foo (l LONG);")
    spark.sql(s"INSERT INTO foo VALUES (${Long.MaxValue});")
    spark.sql("SELECT sequence(0, l) FROM foo;").collect()
    ```
    
    The result of these operations will be:
    ```
    Array[org.apache.spark.sql.Row] = Array([WrappedArray()])
    ```
    an unintended consequence of overflow.
    
    The sequence is applied to values `0` and `Long.MaxValue` with a step size of `1`,
    which uses a length computation defined
    [here](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3451).
    In this calculation, with `start = 0`, `stop = Long.MaxValue`, and `step = 1`, the
    calculated `len` overflows to `Long.MinValue`. The computation, in binary, looks like:
    
    ```
      0111111111111111111111111111111111111111111111111111111111111111
    - 0000000000000000000000000000000000000000000000000000000000000000
    ------------------------------------------------------------------
      0111111111111111111111111111111111111111111111111111111111111111
    / 0000000000000000000000000000000000000000000000000000000000000001
    ------------------------------------------------------------------
      0111111111111111111111111111111111111111111111111111111111111111
    + 0000000000000000000000000000000000000000000000000000000000000001
    ------------------------------------------------------------------
      1000000000000000000000000000000000000000000000000000000000000000
    ```
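
    To make the wrap-around concrete, here is a minimal Scala sketch of the same unchecked
    arithmetic (an illustration of the expressions above, not the Spark code itself):

    ```
    val start = 0L
    val stop = Long.MaxValue
    val step = 1L
    // (stop - start) / step is Long.MaxValue; adding 1 silently wraps to Long.MinValue.
    val len = 1L + (stop - start) / step
    // len == Long.MinValue (-9223372036854775808)
    ```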
    
    The following
    [check](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L3454)
    passes, as the negative `Long.MinValue` is still `<= MAX_ROUNDED_ARRAY_LENGTH`. The
    subsequent cast to `toInt` uses this representation and [truncates the upper
    bits](https://github.com/apache/spark/blob/16411188c7ba6cb19c46a2bd512b2485a4c03e2c/sql/catalyst/src/main/scala/org/apache/spa [...]
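
    Concretely, a minimal sketch of what the old bound check and cast do with the overflowed
    value (using the real `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH` constant):

    ```
    import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

    val len = Long.MinValue                  // the overflowed length from above
    len <= MAX_ROUNDED_ARRAY_LENGTH          // true: a negative length passes the check
    len.toInt                                // 0: the upper 32 bits are truncated away
    // A computed length of 0 is what yields the empty WrappedArray() shown earlier.
    ```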
    
    Other overflows are similarly problematic.
    
    This PR addresses the issue by checking numeric operations in the length computation for
    overflow.
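
    The checks lean on the JDK's exact-arithmetic helpers, which throw instead of silently
    wrapping; a minimal illustration of the difference (independent of the patch below):

    ```
    // Plain addition wraps around on overflow.
    1L + Long.MaxValue                  // Long.MinValue
    // Math.addExact raises instead of wrapping.
    Math.addExact(1L, Long.MaxValue)    // throws java.lang.ArithmeticException: long overflow
    ```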
    
    ### Why are the changes needed?
    There is a correctness bug from overflow in the `sequence` expression.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Tests added in `CollectionExpressionsSuite.scala`.
    
    Closes #41072 from thepinetree/spark-sequence-overflow.
    
    Authored-by: Deepayan Patra <deepayan.pa...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../expressions/collectionOperations.scala         | 47 ++++++++++++------
 .../expressions/CollectionExpressionsSuite.scala   | 56 ++++++++++++++++++++--
 2 files changed, 83 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 196c4a6cdd69..070d9c34f294 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -23,6 +23,7 @@ import scala.collection.mutable
 import scala.reflect.ClassTag
 
 import org.apache.spark.QueryContext
+import org.apache.spark.SparkException.internalError
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion, UnresolvedAttribute, UnresolvedSeed}
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -41,7 +42,6 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.SQLOpenHashSet
 import org.apache.spark.unsafe.UTF8StringBuilder
 import org.apache.spark.unsafe.array.ByteArrayMethods
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
 import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
 
 /**
@@ -3122,6 +3122,34 @@ case class Sequence(
 }
 
 object Sequence {
+  private def prettyName: String = "sequence"
+
+  def sequenceLength(start: Long, stop: Long, step: Long): Int = {
+    try {
+      val delta = Math.subtractExact(stop, start)
+      if (delta == Long.MinValue && step == -1L) {
+        // We must special-case division of Long.MinValue by -1 to catch potential unchecked
+        // overflow in next operation. Division does not have a builtin overflow check. We
+        // previously special-case div-by-zero.
+        throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
+      }
+      val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
+      if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, len)
+      }
+      len.toInt
+    } catch {
+      // We handle overflows in the previous try block by raising an appropriate exception.
+      case _: ArithmeticException =>
+        val safeLen =
+          BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
+        if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+          throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, safeLen)
+        }
+        throw internalError("Unreachable code reached.")
+      case e: Exception => throw e
+    }
+  }
 
   private type LessThanOrEqualFn = (Any, Any) => Boolean
 
@@ -3493,13 +3521,7 @@ object Sequence {
         || (estimatedStep == num.zero && start == stop),
       s"Illegal sequence boundaries: $start to $stop by $step")
 
-    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
-
-    require(
-      len <= MAX_ROUNDED_ARRAY_LENGTH,
-      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
-    len.toInt
+    sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
   }
 
   private def genSequenceLengthCode(
@@ -3509,7 +3531,7 @@ object Sequence {
       step: String,
       estimatedStep: String,
       len: String): String = {
-    val longLen = ctx.freshName("longLen")
+    val calcFn = classOf[Sequence].getName + ".sequenceLength"
     s"""
        |if (!(($estimatedStep > 0 && $start <= $stop) ||
        |  ($estimatedStep < 0 && $start >= $stop) ||
@@ -3517,12 +3539,7 @@ object Sequence {
        |  throw new IllegalArgumentException(
       |    "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
        |}
-       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
-       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
-       |  throw new IllegalArgumentException(
-       |    "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
-       |}
-       |int $len = (int) $longLen;
+       |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
        """.stripMargin
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index b3a21bd9565f..69db27779389 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.util.IntervalUtils._
 import org.apache.spark.sql.errors.DataTypeErrorsBase
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.UTF8String
 
 class CollectionExpressionsSuite
@@ -795,10 +795,6 @@ class CollectionExpressionsSuite
 
     // test sequence boundaries checking
 
-    checkExceptionInExpression[IllegalArgumentException](
-      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
-      EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
-
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
     checkExceptionInExpression[IllegalArgumentException](
@@ -808,6 +804,56 @@ class CollectionExpressionsSuite
     checkExceptionInExpression[IllegalArgumentException](
       new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow, "boundaries: 1 to 2 by -1")
 
+    // SPARK-43393: test Sequence overflow checking
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue), Literal(1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue), Literal(-1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+    checkErrorInExpression[SparkRuntimeException](
+      new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
+      errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
+      parameters = Map(
+        "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString,
+        "functionName" -> toSQLId("sequence"),
+        "maxRoundedArrayLength" -> ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
+        "parameter" -> toSQLId("count")))
+
     // test sequence with one element (zero step or equal start and stop)
 
     checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))

