Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21155#discussion_r197676498
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -2288,6 +2288,401 @@ case class Flatten(child: Expression) extends 
UnaryExpression {
       override def prettyName: String = "flatten"
     }
     
    +@ExpressionDescription(
    +  usage = """
    +    _FUNC_(start, stop, step) - Generates an array of elements from start 
to stop (inclusive),
    +      incrementing by step. The type of the returned elements is the same 
as the type of argument
    +      expressions.
    +
    +      Supported types are: byte, short, integer, long, date, timestamp.
    +
    +      The start and stop expressions must resolve to the same type.
    +      If start and stop expressions resolve to the 'date' or 'timestamp' 
type
    +      then the step expression must resolve to the 'interval' type, 
otherwise to the same type
    +      as the start and stop expressions.
    +  """,
    +  arguments = """
    +    Arguments:
    +      * start - an expression. The start of the range.
    +      * stop - an expression. The end the range (inclusive).
    +      * step - an optional expression. The step of the range.
    +          By default step is 1 if start is less than or equal to stop, 
otherwise -1.
    +          For the temporal sequences it's 1 day and -1 day respectively.
    +          If start is greater than stop then the step must be negative, 
and vice versa.
    +  """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(1, 5);
    +       [1, 2, 3, 4, 5]
    +      > SELECT _FUNC_(5, 1);
    +       [5, 4, 3, 2, 1]
    +      > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), 
interval 1 month);
    +       [2018-01-01, 2018-02-01, 2018-03-01]
    +  """,
    +  since = "2.4.0"
    +)
    +case class Sequence(
    +    start: Expression,
    +    stop: Expression,
    +    stepOpt: Option[Expression],
    +    timeZoneId: Option[String] = None)
    +  extends Expression
    +  with TimeZoneAwareExpression {
    +
    +  import Sequence._
    +
    +  def this(start: Expression, stop: Expression) =
    +    this(start, stop, None, None)
    +
    +  def this(start: Expression, stop: Expression, step: Expression) =
    +    this(start, stop, Some(step), None)
    +
    +  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    +    copy(timeZoneId = Some(timeZoneId))
    +
    +  override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt
    +
    +  override def foldable: Boolean = children.forall(_.foldable)
    +
    +  override def nullable: Boolean = children.exists(_.nullable)
    +
    +  override lazy val dataType: ArrayType = ArrayType(start.dataType, 
containsNull = false)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val startType = start.dataType
    +    def stepType = stepOpt.get.dataType
    +    val typesCorrect =
    +      startType.sameType(stop.dataType) &&
    +        (startType match {
    +          case TimestampType | DateType =>
    +            stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType)
    +          case _: IntegralType =>
    +            stepOpt.isEmpty || stepType.sameType(startType)
    +          case _ => false
    +        })
    +
    +    if (typesCorrect) {
    +      TypeCheckResult.TypeCheckSuccess
    +    } else {
    +      TypeCheckResult.TypeCheckFailure(
    +        s"$prettyName only supports integral, timestamp or date types")
    +    }
    +  }
    +
    +  def coercibleChildren: Seq[Expression] = children.filter(_.dataType != 
CalendarIntervalType)
    +
    +  def castChildrenTo(widerType: DataType): Expression = Sequence(
    +    Cast(start, widerType),
    +    Cast(stop, widerType),
    +    stepOpt.map(step => if (step.dataType != CalendarIntervalType) 
Cast(step, widerType) else step),
    +    timeZoneId)
    +
    +  private lazy val impl: SequenceImpl = dataType.elementType match {
    +    case iType: IntegralType =>
    +      type T = iType.InternalType
    +      val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
    +      new IntegralSequenceImpl(iType)(ct, iType.integral)
    +
    +    case TimestampType =>
    +      new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone)
    +
    +    case DateType =>
    +      new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, 
timeZone)
    +  }
    +
    +  override def eval(input: InternalRow): Any = {
    +    val startVal = start.eval(input)
    +    if (startVal == null) return null
    +    val stopVal = stop.eval(input)
    +    if (stopVal == null) return null
    +    val stepVal = 
stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal))
    +    if (stepVal == null) return null
    +
    +    ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal))
    +  }
    +
    +  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
    +    val startGen = start.genCode(ctx)
    +    val stopGen = stop.genCode(ctx)
    +    val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse(
    +      impl.defaultStep.genCode(ctx, startGen, stopGen))
    +
    +    val resultType = CodeGenerator.javaType(dataType)
    +    val resultCode = {
    +      val arr = ctx.freshName("arr")
    +      val arrElemType = CodeGenerator.javaType(dataType.elementType)
    +      s"""
    +         |final $arrElemType[] $arr = null;
    +         |${impl.genCode(ctx, startGen.value, stopGen.value, 
stepGen.value, arr, arrElemType)}
    +         |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr);
    +       """.stripMargin
    +    }
    +
    +    if (nullable) {
    +      val nullSafeEval =
    +        startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) {
    +          stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) {
    +            stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), 
stepGen.isNull) {
    +              s"""
    +                 |${ev.isNull} = false;
    +                 |$resultCode
    +               """.stripMargin
    +            }
    +          }
    +        }
    +      ev.copy(code =
    +        code"""
    +           |boolean ${ev.isNull} = true;
    +           |$resultType ${ev.value} = null;
    +           |$nullSafeEval
    +         """.stripMargin)
    +
    +    } else {
    +      ev.copy(code =
    +        code"""
    +           |${startGen.code}
    +           |${stopGen.code}
    +           |${stepGen.code}
    +           |$resultType ${ev.value} = null;
    +           |$resultCode
    +         """.stripMargin,
    +        isNull = FalseLiteral)
    +    }
    +  }
    +}
    +
    +object Sequence {
    +
    +  private type LessThanOrEqualFn = (Any, Any) => Boolean
    +
    +  private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, 
one: Any) {
    +    private val negativeOne = UnaryMinus(Literal(one)).eval()
    +
    +    def apply(start: Any, stop: Any): Any = {
    +      if (lteq(start, stop)) one else negativeOne
    +    }
    +
    +    def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: 
ExprCode): ExprCode = {
    +      val Seq(oneVal, negativeOneVal) = Seq(one, 
negativeOne).map(Literal(_).genCode(ctx).value)
    +      ExprCode.forNonNullValue(JavaCode.expression(
    +        s"${startGen.value} <= ${stopGen.value} ? $oneVal : 
$negativeOneVal",
    +        stepType))
    +    }
    +  }
    +
    +  private trait SequenceImpl {
    +    def eval(start: Any, stop: Any, step: Any): Any
    +
    +    def genCode(
    +        ctx: CodegenContext,
    +        start: String,
    +        stop: String,
    +        step: String,
    +        arr: String,
    +        elemType: String): String
    +
    +    val defaultStep: DefaultStep
    +  }
    +
    +  private class IntegralSequenceImpl[T: ClassTag]
    +    (elemType: IntegralType)(implicit num: Integral[T]) extends 
SequenceImpl {
    +
    +    override val defaultStep: DefaultStep = new DefaultStep(
    +      (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
    +      elemType,
    +      num.one)
    +
    +    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
    +      import num._
    +
    +      val start = input1.asInstanceOf[T]
    +      val stop = input2.asInstanceOf[T]
    +      val step = input3.asInstanceOf[T]
    +
    +      var i: Int = getSequenceLength(start, stop, step)
    +      val arr = new Array[T](i)
    +      while (i > 0) {
    +        i -= 1
    +        arr(i) = start + step * num.fromInt(i)
    +      }
    +      arr
    +    }
    +
    +    override def genCode(
    +        ctx: CodegenContext,
    +        start: String,
    +        stop: String,
    +        step: String,
    +        arr: String,
    +        elemType: String): String = {
    +      val i = ctx.freshName("i")
    +      s"""
    +         |${genSequenceLengthCode(ctx, start, stop, step, i)}
    +         |$arr = new $elemType[$i];
    +         |while ($i > 0) {
    +         |  $i--;
    +         |  $arr[$i] = ($elemType) ($start + $step * $i);
    +         |}
    +         """.stripMargin
    +    }
    +  }
    +
    +  private class TemporalSequenceImpl[T: ClassTag]
    +      (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: 
TimeZone)
    +      (implicit num: Integral[T]) extends SequenceImpl {
    +
    +    override val defaultStep: DefaultStep = new DefaultStep(
    +      (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
    +      CalendarIntervalType,
    +      new CalendarInterval(0, MICROS_PER_DAY))
    +
    +    private val backedSequenceImpl = new IntegralSequenceImpl[T](dt)
    +    private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY
    +
    +    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
    +      val start = input1.asInstanceOf[T]
    +      val stop = input2.asInstanceOf[T]
    +      val step = input3.asInstanceOf[CalendarInterval]
    +      val stepMonths = step.months
    +      val stepMicros = step.microseconds
    +
    +      if (stepMonths == 0) {
    +        backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale))
    +
    +      } else {
    +        // To estimate the resulted array length we need to make 
assumptions
    +        // about a month length in microseconds
    +        val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth
    +        val startMicros: Long = num.toLong(start) * scale
    +        val stopMicros: Long = num.toLong(stop) * scale
    +        val maxEstimatedArrayLength =
    +          getSequenceLength(startMicros, stopMicros, intervalStepInMicros)
    +
    +        val stepSign = if (stopMicros > startMicros) +1 else -1
    +        val exclusiveItem = stopMicros + stepSign
    +        val arr = new Array[T](maxEstimatedArrayLength)
    +        var t = startMicros
    +        var i = 0
    +
    +        while (t < exclusiveItem ^ stepSign < 0) {
    +          arr(i) = fromLong(t / scale)
    +          t = timestampAddInterval(t, stepMonths, stepMicros, timeZone)
    +          i += 1
    +        }
    +
    +        // truncate array to the correct length
    +        if (arr.length == i) arr else arr.slice(0, i)
    +      }
    +    }
    +
    +    override def genCode(
    +        ctx: CodegenContext,
    +        start: String,
    +        stop: String,
    +        step: String,
    +        arr: String,
    +        elemType: String): String = {
    +      val stepMonths = ctx.freshName("stepMonths")
    +      val stepMicros = ctx.freshName("stepMicros")
    +      val stepScaled = ctx.freshName("stepScaled")
    +      val intervalInMicros = ctx.freshName("intervalInMicros")
    +      val startMicros = ctx.freshName("startMicros")
    +      val stopMicros = ctx.freshName("stopMicros")
    +      val arrLength = ctx.freshName("arrLength")
    +      val stepSign = ctx.freshName("stepSign")
    +      val exclusiveItem = ctx.freshName("exclusiveItem")
    +      val t = ctx.freshName("t")
    +      val i = ctx.freshName("i")
    +      val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, 
classOf[TimeZone].getName)
    +
    +      val sequenceLengthCode =
    +        s"""
    +           |final long $intervalInMicros = $stepMicros + $stepMonths * 
${microsPerMonth}L;
    +           |${genSequenceLengthCode(ctx, startMicros, stopMicros, 
intervalInMicros, arrLength)}
    +          """.stripMargin
    +
    +      val timestampAddIntervalCode =
    +        s"""
    +           |$t = 
org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
    +           |  $t, $stepMonths, $stepMicros, $genTimeZone);
    +          """.stripMargin
    +
    +      s"""
    +         |final int $stepMonths = $step.months;
    +         |final long $stepMicros = $step.microseconds;
    +         |
    +         |if ($stepMonths == 0) {
    +         |  final $elemType $stepScaled = ($elemType) ($stepMicros / 
${scale}L);
    +         |  ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, 
arr, elemType)};
    +         |
    +         |} else {
    +         |  final long $startMicros = $start * ${scale}L;
    +         |  final long $stopMicros = $stop * ${scale}L;
    +         |
    +         |  $sequenceLengthCode
    +         |
    +         |  final int $stepSign = $stopMicros > $startMicros ? +1 : -1;
    +         |  final long $exclusiveItem = $stopMicros + $stepSign;
    +         |
    +         |  $arr = new $elemType[$arrLength];
    +         |  long $t = $startMicros;
    +         |  int $i = 0;
    +         |
    +         |  while ($t < $exclusiveItem ^ $stepSign < 0) {
    +         |    $arr[$i] = ($elemType) ($t / ${scale}L);
    +         |    $timestampAddIntervalCode
    +         |    $i += 1;
    +         |  }
    +         |
    +         |  if ($arr.length > $i) {
    +         |    $arr = java.util.Arrays.copyOf($arr, $i);
    +         |  }
    +         |}
    +         """.stripMargin
    +    }
    +  }
    +
    +  private def getSequenceLength[U](start: U, stop: U, step: U)(implicit 
num: Integral[U]): Int = {
    +    import num._
    +    require(
    +      (step > num.zero && start <= stop)
    +        || (step < num.zero && start >= stop)
    +        || (step == 0 && start == stop),
    +      s"Illegal sequence boundaries: $start to $stop by $step")
    +
    +    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) 
/ step.toLong
    +
    +    require(
    +      len <= MAX_ROUNDED_ARRAY_LENGTH,
    +      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")
    +
    +    len.toInt
    +  }
    +
    +  private def genSequenceLengthCode(
    +      ctx: CodegenContext,
    +      start: String,
    +      stop: String,
    +      step: String,
    +      len: String): String = {
    +    val longLen = ctx.freshName("longLen")
    +    s"""
    +       |if (!(($step > 0 && $start <= $stop) ||
    +       |  ($step < 0 && $start >= $stop) ||
    +       |  ($step == 0 && $start == $stop))) {
    +       |  throw new IllegalArgumentException(
    +       |    "Illegal sequence boundaries: " + $start + " to " + $stop + " 
by " + $step);
    +       |}
    +       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - 
$start) / $step;
    +       |if ($longLen > Integer.MAX_VALUE) {
    --- End diff --
    
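    Should this be `MAX_ROUNDED_ARRAY_LENGTH` instead of `Integer.MAX_VALUE`? The interpreted path (`getSequenceLength`) rejects lengths greater than `MAX_ROUNDED_ARRAY_LENGTH`, so the generated code should check the same limit; otherwise codegen and `eval` would disagree on which sequence lengths are allowed.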


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to