thepinetree commented on code in PR #41072:
URL: https://github.com/apache/spark/pull/41072#discussion_r1392991421
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -3122,6 +3122,34 @@ case class Sequence(
}
object Sequence {
+  private def prettyName: String = "sequence"
+
+  def sequenceLength(start: Long, stop: Long, step: Long): Int = {
+    try {
+      val delta = Math.subtractExact(stop, start)
+      if (delta == Long.MinValue && step == -1L) {
+        // We must special-case division of Long.MinValue by -1 to catch potential unchecked
+        // overflow in the next operation. Division does not have a built-in overflow check;
+        // we previously special-cased only div-by-zero.
+        throw new ArithmeticException("Long overflow (Long.MinValue / -1)")
+      }
+      val len = if (stop == start) 1L else Math.addExact(1L, (delta / step))
+      if (len > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+        throw QueryExecutionErrors.createArrayWithElementsExceedLimitError(prettyName, len)
+      }
+      len.toInt
+    } catch {
+      // We handle overflows in the try block above by raising an appropriate exception.
+      case _: ArithmeticException =>
+        val safeLen =
+          BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
+        if (safeLen > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
Review Comment:
I personally like the current exception better since it's more descriptive
of the actual problem -- trying to create too large an array (with the user's
intended size) and what the limit is. If you feel strongly about it, I can
change it to an assertion.
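
For context, here's a minimal standalone sketch (`SequenceLengthDemo` is a
made-up name, not part of the patch) reproducing the two overflow behaviors
the try/catch above relies on: `Long` division wraps silently, while the
`*Exact` helpers throw, and the `BigInt` fallback stays exact:

```scala
object SequenceLengthDemo {
  def main(args: Array[String]): Unit = {
    // Long division has no built-in overflow check: Long.MinValue / -1L
    // silently wraps back to Long.MinValue, hence the special case above.
    println(Long.MinValue / -1L) // -9223372036854775808

    // Math.subtractExact / Math.addExact do throw on overflow, so those
    // paths already land in the catch block.
    try Math.subtractExact(Long.MinValue, 1L)
    catch { case e: ArithmeticException => println(e.getMessage) } // "long overflow"

    // The BigInt fallback computes the user's intended length exactly,
    // e.g. stepping from Long.MaxValue down to Long.MinValue by -1:
    val (start, stop, step) = (Long.MaxValue, Long.MinValue, -1L)
    val safeLen = BigInt(1) + (BigInt(stop) - BigInt(start)) / BigInt(step)
    println(safeLen) // 18446744073709551616 (2^64), far above the array limit
  }
}
```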
##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala:
##########
@@ -3509,20 +3531,15 @@ object Sequence {
      step: String,
      estimatedStep: String,
      len: String): String = {
-    val longLen = ctx.freshName("longLen")
+    val calcFn = "Sequence.sequenceLength"
Review Comment:
Not sure exactly how, but my hypothesis is that these factors all play a role:
* The `Sequence` object ends up in the same compilation unit as this function
(not sure if this is expected).
* The `sequenceLength` function is effectively static in Scala and publicly
accessible (see the sketch below).

I noticed some other functions in this file do this as well -- e.g.
https://github.com/apache/spark/blob/128f5523194d5241c7b0f08b5be183288128ba16/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L613

I do like your suggestion better, though -- it's cleaner and easier to understand.
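
To make the "effectively static" point concrete, here's a hypothetical
standalone sketch (`Box` and `StaticForwarderDemo` are made-up stand-ins for
`Sequence` / `Sequence.sequenceLength`): scalac emits a static forwarder on
the companion class for each public companion-object method, which -- as far
as I can tell -- is what lets plain Java, and therefore Janino-compiled
generated code, call it directly:

```scala
case class Box(v: Long)

object Box {
  // Same shape as Sequence.sequenceLength in the patch above.
  def boxLength(start: Long, stop: Long, step: Long): Int =
    (1L + (stop - start) / step).toInt
}

object StaticForwarderDemo {
  def main(args: Array[String]): Unit = {
    // The forwarder is a genuine static method on the companion class, so
    // generated Java can write Box.boxLength(0L, 10L, 1L) with no
    // Box$.MODULE$ plumbing.
    val m = classOf[Box].getMethod(
      "boxLength", classOf[Long], classOf[Long], classOf[Long])
    println(java.lang.reflect.Modifier.isStatic(m.getModifiers)) // true
    println(m.invoke(null, Long.box(0L), Long.box(10L), Long.box(1L))) // 11
  }
}
```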