This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new e38310c74e6 Revert "[SPARK-43393][SQL] Address sequence expression
overflow bug"
e38310c74e6 is described below
commit e38310c74e6cae8c8c8489ffcbceb80ed37a7cae
Author: Dongjoon Hyun <[email protected]>
AuthorDate: Wed Nov 15 09:12:42 2023 -0800
Revert "[SPARK-43393][SQL] Address sequence expression overflow bug"
This reverts commit 41a7a4a3233772003aef380428acd9eaf39b9a93.
---
.../expressions/collectionOperations.scala | 48 ++++++-------------
.../expressions/CollectionExpressionsSuite.scala | 56 ++--------------------
2 files changed, 20 insertions(+), 84 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c3c235fba67..ade4a6c5be7 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -22,8 +22,6 @@ import java.util.Comparator
import scala.collection.mutable
import scala.reflect.ClassTag
-import org.apache.spark.QueryContext
-import org.apache.spark.SparkException.internalError
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion,
UnresolvedAttribute, UnresolvedSeed}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
@@ -42,6 +40,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SQLOpenHashSet
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, CalendarInterval, UTF8String}
/**
@@ -3081,34 +3080,6 @@ case class Sequence(
}
object Sequence {
- private def prettyName: String = "sequence"
-
- def sequenceLength(start: Long, stop: Long, step: Long): Int = {
- try {
- val delta = Math.subtractExact(stop, start)
- if (delta == Long.MinValue && step == -1L) {
- // We must special-case division of Long.MinValue by -1 to catch potential unchecked
- // overflow in next operation. Division does not have a builtin overflow check. We
- // previously special-case div-by-zero.
- throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
- }
- val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
- if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, len)
- }
- len.toInt
- } catch {
- // We handle overflows in the previous try block by raising an appropriate exception.
- case _: ArithmeticException =>
- val safeLen =
- BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
- if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
- throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, safeLen)
- }
- throw internalError("Unreachable code reached.")
- case e: Exception => throw e
- }
- }
private type LessThanOrEqualFn = (Any, Any) => Boolean
@@ -3480,7 +3451,13 @@ object Sequence {
|| (estimatedStep == num.zero && start == stop),
s"Illegal sequence boundaries: $start to $stop by $step")
- sequenceLength(start.toLong, stop.toLong, estimatedStep.toLong)
+ val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / estimatedStep.toLong
+
+ require(
+ len <= MAX_ROUNDED_ARRAY_LENGTH,
+ s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
+
+ len.toInt
}
private def genSequenceLengthCode(
@@ -3490,7 +3467,7 @@ object Sequence {
step: String,
estimatedStep: String,
len: String): String = {
- val calcFn = classOf[Sequence].getName + ".sequenceLength"
+ val longLen = ctx.freshName("longLen")
s"""
|if (!(($estimatedStep > 0 && $start <= $stop) ||
| ($estimatedStep < 0 && $start >= $stop) ||
@@ -3498,7 +3475,12 @@ object Sequence {
| throw new IllegalArgumentException(
| "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
|}
- |int $len = $calcFn((long) $start, (long) $stop, (long) $estimatedStep);
+ |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $estimatedStep;
+ |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
+ | throw new IllegalArgumentException(
+ | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
+ |}
+ |int $len = (int) $longLen;
""".stripMargin
}
}
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index d001006c58c..1787f6ac72d 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -34,7 +34,7 @@ import
org.apache.spark.sql.catalyst.util.DateTimeTestUtils.{outstandingZoneIds,
import org.apache.spark.sql.catalyst.util.IntervalUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.UTF8String
class CollectionExpressionsSuite extends SparkFunSuite with
ExpressionEvalHelper {
@@ -769,6 +769,10 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
// test sequence boundaries checking
+ checkExceptionInExpression[IllegalArgumentException](
+ new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
+ EmptyRow, s"Too long sequence: 4294967296. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
+
checkExceptionInExpression[IllegalArgumentException](
new Sequence(Literal(1), Literal(2), Literal(0)), EmptyRow, "boundaries: 1 to 2 by 0")
checkExceptionInExpression[IllegalArgumentException](
@@ -778,56 +782,6 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
checkExceptionInExpression[IllegalArgumentException](
new Sequence(Literal(1), Literal(2), Literal(-1)), EmptyRow,
"boundaries: 1 to 2 by -1")
- // SPARK-43393: test Sequence overflow checking
- checkErrorInExpression[SparkRuntimeException](
- new Sequence(Literal(Int.MinValue), Literal(Int.MaxValue), Literal(1)),
- errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
- parameters = Map(
- "numberOfElements" -> (BigInt(Int.MaxValue) - BigInt { Int.MinValue } + 1).toString,
- "functionName" -> toSQLId("sequence"),
- "maxRoundedArrayLength" ->
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
- "parameter" -> toSQLId("count")))
- checkErrorInExpression[SparkRuntimeException](
- new Sequence(Literal(0L), Literal(Long.MaxValue), Literal(1L)),
- errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
- parameters = Map(
- "numberOfElements" -> (BigInt(Long.MaxValue) + 1).toString,
- "functionName" -> toSQLId("sequence"),
- "maxRoundedArrayLength" ->
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
- "parameter" -> toSQLId("count")))
- checkErrorInExpression[SparkRuntimeException](
- new Sequence(Literal(0L), Literal(Long.MinValue), Literal(-1L)),
- errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
- parameters = Map(
- "numberOfElements" -> ((0 - BigInt(Long.MinValue)) + 1).toString(),
- "functionName" -> toSQLId("sequence"),
- "maxRoundedArrayLength" ->
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
- "parameter" -> toSQLId("count")))
- checkErrorInExpression[SparkRuntimeException](
- new Sequence(Literal(Long.MinValue), Literal(Long.MaxValue),
Literal(1L)),
- errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
- parameters = Map(
- "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
- "functionName" -> toSQLId("sequence"),
- "maxRoundedArrayLength" ->
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
- "parameter" -> toSQLId("count")))
- checkErrorInExpression[SparkRuntimeException](
- new Sequence(Literal(Long.MaxValue), Literal(Long.MinValue),
Literal(-1L)),
- errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
- parameters = Map(
- "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { Long.MinValue } + 1).toString,
- "functionName" -> toSQLId("sequence"),
- "maxRoundedArrayLength" ->
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
- "parameter" -> toSQLId("count")))
- checkErrorInExpression[SparkRuntimeException](
- new Sequence(Literal(Long.MaxValue), Literal(-1L), Literal(-1L)),
- errorClass = "COLLECTION_SIZE_LIMIT_EXCEEDED.PARAMETER",
- parameters = Map(
- "numberOfElements" -> (BigInt(Long.MaxValue) - BigInt { -1L } + 1).toString,
- "functionName" -> toSQLId("sequence"),
- "maxRoundedArrayLength" ->
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH.toString(),
- "parameter" -> toSQLId("count")))
-
// test sequence with one element (zero step or equal start and stop)
checkEvaluation(new Sequence(Literal(1), Literal(1), Literal(-1)), Seq(1))
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]