GitHub user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21155#discussion_r184125363
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -1059,3 +1063,282 @@ case class Flatten(child: Expression) extends 
UnaryExpression {
     
       override def prettyName: String = "flatten"
     }
    +
    +object Sequence {
    +  private def defaultStepLiteralForType(dt: DataType) = dt match {
    +    case DateType | TimestampType => Literal(new CalendarInterval(0, 
MICROS_PER_DAY))
    +    case dt: NumericType => Literal(dt.numeric.one)
    +  }
    +
    +  private def defaultStepExpression(start: Expression, stop: Expression): 
Expression = {
    +    val one =
    +      if (start.resolved) {
    +        defaultStepLiteralForType(start.dataType)
    +      } else {
    +        UnresolvedLiteral(start, defaultStepLiteralForType)
    +      }
    +    CaseWhen(
    +      Seq((LessThanOrEqual(start, stop), one)),
    +      UnaryMinus(one))
    +  }
    +
    +  def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: 
Integral[U]): Int = {
    +    import num._
    +    require(
    +      step > num.zero && start <= stop
    +        || step < num.zero && start >= stop
    +        || step == 0 && start == stop,
    +      s"Illegal sequence boundaries: $start to $stop by $step")
    +
    +    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) 
/ step.toLong
    +
    +    require(len <= Int.MaxValue, s"Too long sequence: $len. Should be <= 
${Int.MaxValue}")
    +
    +    len.toInt
    +  }
    +
    +  def genSequenceLengthCode(start: String, stop: String, step: String)
    +                       (len: String)
    +                       (implicit ctx: CodegenContext): String = {
    +    val longLen = ctx.freshName("longLen")
    +    s"""
    +       |  if (!($step > 0 && $start <= $stop ||
    +       |    $step < 0 && $start >= $stop ||
    +       |    $step == 0 && $start == $stop)) {
    +       |    throw new IllegalArgumentException(
    +       |      "Illegal sequence boundaries: " + $start + " to " + $stop + 
" by " + $step);
    +       |  }
    +       |  long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - 
$start) / $step;
    +       |  if ($longLen > Integer.MAX_VALUE) {
    +       |    throw new IllegalArgumentException(
    +       |      "Too long sequence: " + $longLen + ". Should be <= 
${Int.MaxValue}");
    +       |  }
    +       |  int $len = (int) $longLen;
    +       """.stripMargin
    +  }
    +
    +  trait SequenceImpl {
    +    def eval(input1: Any, input2: Any, input3: Any): Any
    +    def genCode(start: String, stop: String, step: String)
    +               (arr: String, elemType: String)
    +               (implicit ctx: CodegenContext): String
    +  }
    +}
    +
    +case class Sequence(left: Expression,
    +                    middle: Expression,
    +                    right: Expression,
    +                    timeZoneId: Option[String] = None)
    +  extends TernaryExpression with TimeZoneAwareExpression {
    +
    +  import Sequence._
    +
    +  class IntegralSequenceImpl[T: ClassTag](implicit num: Integral[T]) 
extends SequenceImpl {
    +
    +    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
    +      import num._
    +
    +      val start = input1.asInstanceOf[T]
    +      val stop = input2.asInstanceOf[T]
    +      val step = input3.asInstanceOf[T]
    +
    +      var i: Int = getSequenceLength(start, stop, step)
    +      val arr = new Array[T](i)
    +      while (i > 0) {
    +        i -= 1
    +        arr(i) = start + step * num.fromInt(i)
    +      }
    +      arr
    +    }
    +
    +    override def genCode(start: String, stop: String, step: String)
    +                        (arr: String, elemType: String)
    +                        (implicit ctx: CodegenContext): String = {
    +      val i = ctx.freshName("i")
    +      s"""
    +         |  ${genSequenceLengthCode(start, stop, step)(i)}
    +         |  $arr = new $elemType[$i];
    +         |  while ($i > 0) {
    +         |    $i--;
    +         |    $arr[$i] = ($elemType) ($start + $step * $i);
    +         |  }
    +         """.stripMargin
    +    }
    +  }
    +
    +  abstract class TemporalSequence[T: ClassTag](dt: IntegralType, scale: 
Long, fromLong: Long => T)
    +                                              (implicit num: Integral[T])
    +    extends SequenceImpl {
    +
    +    private val backedSequenceImpl = new IntegralSequenceImpl[T]
    +    private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY
    +
    +    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
    +      val start = input1.asInstanceOf[T]
    +      val stop = input2.asInstanceOf[T]
    +      val step = input3.asInstanceOf[CalendarInterval]
    +      val stepMonths = step.months
    +      val stepMicros = step.microseconds
    +
    +      if (stepMonths == 0) {
    +        backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale))
    +
    +      } else {
    +        // To estimate the resulted array length we need to make 
assumptions
    +        // about a month length in microseconds
    +        val minIntervalLengthInMicros = Math.abs(stepMicros) + stepMonths 
* microsPerMonth
    +        val startMicros: Long = num.toLong(start) * scale
    +        val stopMicros: Long = num.toLong(stop) * scale
    +        val maxEstimatedArrayLength =
    +          getSequenceLength(startMicros, stopMicros, 
minIntervalLengthInMicros)
    +
    +        val stepSign = if (stopMicros > startMicros) +1 else -1
    +        val exclusiveItem = stopMicros + stepSign
    +        val arr = new Array[T](maxEstimatedArrayLength)
    +        var t = startMicros
    +        var i = 0
    +
    +        while (t < exclusiveItem ^ stepSign < 0) {
    +          arr(i) = fromLong(t / scale)
    +          t = timestampAddInterval(t, stepMonths, stepMicros, timeZone)
    +          i += 1
    +        }
    +
    +        // truncate array to the correct length
    +        if (arr.length == i) {
    +          arr
    +        } else {
    +          arr.slice(0, i)
    +        }
    +      }
    +    }
    +
    +    override def genCode(start: String, stop: String, step: String)
    +                        (arr: String, elemType: String)
    +                        (implicit ctx: CodegenContext): String = {
    +      val stepMonths = ctx.freshName("stepMonths")
    +      val stepMicros = ctx.freshName("stepMicros")
    +
    +      lazy val stepScaled = ctx.freshName("stepScaled")
    --- End diff --
    
    Is there any reason to use `lazy` here?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to