GitHub user wajda commented on a diff in the pull request:
https://github.com/apache/spark/pull/21155#discussion_r197496187
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -1887,6 +1889,402 @@ case class Flatten(child: Expression) extends
UnaryExpression {
override def prettyName: String = "flatten"
}
+/**
+ * Expression that generates an array of elements from `start` to `stop` (inclusive),
+ * incrementing by an optional step.
+ *
+ * `start` and `stop` must have the same type. For integral types the step must have that same
+ * type; for date and timestamp types the step must be a calendar interval. When no step is
+ * given, a default step is derived from the ordering of `start` and `stop`.
+ *
+ * @param start      expression producing the first element of the sequence
+ * @param stop       expression producing the inclusive end of the sequence
+ * @param stepOpt    optional expression producing the increment
+ * @param timeZoneId optional time zone id, passed on to the temporal implementations
+ */
+@ExpressionDescription(
+  usage = """
+    _FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive),
+      incrementing by step. The type of the returned elements is the same as the type of argument
+      expressions.
+
+      Supported types are: byte, short, integer, long, date, timestamp.
+
+      The start and stop expressions must resolve to the same type.
+      If start and stop expressions resolve to the 'date' or 'timestamp' type
+      then the step expression must resolve to the 'interval' type, otherwise to the same type
+      as the start and stop expressions.
+  """,
+  arguments = """
+    Arguments:
+      * start - an expression. The start of the range.
+      * stop - an expression. The end of the range (inclusive).
+      * step - an optional expression. The step of the range.
+          By default step is 1 if start is less than or equal to stop, otherwise -1.
+          For the temporal sequences it's 1 day and -1 day respectively.
+          If start is greater than stop then the step must be negative, and vice versa.
+  """,
+  examples = """
+    Examples:
+      > SELECT _FUNC_(1, 5);
+       [1, 2, 3, 4, 5]
+      > SELECT _FUNC_(5, 1);
+       [5, 4, 3, 2, 1]
+      > SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month);
+       [2018-01-01, 2018-02-01, 2018-03-01]
+  """,
+  since = "2.4.0"
+)
+case class Sequence(
+    start: Expression,
+    stop: Expression,
+    stepOpt: Option[Expression],
+    timeZoneId: Option[String] = None)
+  extends Expression
+  with TimeZoneAwareExpression {
+
+  import Sequence._
+
+  /** Two-argument form: no explicit step and no time zone. */
+  def this(start: Expression, stop: Expression) =
+    this(start, stop, None, None)
+
+  /** Three-argument form: explicit step, no time zone. */
+  def this(start: Expression, stop: Expression, step: Expression) =
+    this(start, stop, Some(step), None)
+
+  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
+    copy(timeZoneId = Some(timeZoneId))
+
+  // The step is only a child when it was supplied explicitly.
+  override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  // The element type is the type of `start`; the produced array never contains nulls.
+  override lazy val dataType: ArrayType = ArrayType(start.dataType, containsNull = false)
+
+  /**
+   * Accepts the input when `start` and `stop` have the same type and the step, if present,
+   * is an interval (for timestamp/date) or has the same type as `start` (for integral types).
+   * Every other combination is rejected.
+   */
+  override def checkInputDataTypes(): TypeCheckResult = {
+    val startType = start.dataType
+    val typesCorrect =
+      startType.sameType(stop.dataType) &&
+        (startType match {
+          case TimestampType | DateType =>
+            // `forall` is true for a missing step, so no `Option.get` is needed.
+            stepOpt.forall(step => CalendarIntervalType.acceptsType(step.dataType))
+          case _: IntegralType =>
+            stepOpt.forall(step => step.dataType.sameType(startType))
+          case _ => false
+        })
+
+    if (typesCorrect) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"$prettyName only supports integral, timestamp or date types")
+    }
+  }
+
+  /** Children whose type is not the calendar interval type. */
+  def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType)
+
+  /**
+   * Returns a copy in which `start`, `stop` and a non-interval step are cast to `widerType`;
+   * an interval step is kept unchanged.
+   */
+  def castChildrenTo(widerType: DataType): Expression = Sequence(
+    Cast(start, widerType),
+    Cast(stop, widerType),
+    stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
+    timeZoneId)
+
+  // Implementation selected by element type. Only the types accepted by
+  // `checkInputDataTypes` are matched here.
+  private lazy val impl: SequenceImpl = dataType.elementType match {
+    case iType: IntegralType =>
+      type T = iType.InternalType
+      val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
+      new IntegralSequenceImpl(iType)(ct, iType.integral)
+
+    case TimestampType =>
+      new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone)
+
+    case DateType =>
+      new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone)
+  }
+
+  /**
+   * Evaluates the children in order and returns null as soon as one of them is null;
+   * otherwise delegates to the selected implementation and wraps its result as array data.
+   */
+  override def eval(input: InternalRow): Any = {
+    val startVal = start.eval(input)
+    if (startVal == null) return null
+    val stopVal = stop.eval(input)
+    if (stopVal == null) return null
+    // A missing step falls back to the implementation's default for these boundaries.
+    val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal))
+    if (stepVal == null) return null
+
+    ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal))
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val startGen = start.genCode(ctx)
+    val stopGen = stop.genCode(ctx)
+    // Without an explicit step the default step code is generated from start and stop.
+    val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse(
+      impl.defaultStep.genCode(ctx, startGen, stopGen))
+
+    val resultType = CodeGenerator.javaType(dataType)
+    // Declares a primitive array, lets the implementation fill it, then wraps it.
+    val resultCode = {
+      val arr = ctx.freshName("arr")
+      val arrElemType = CodeGenerator.javaType(dataType.elementType)
+      s"""
+         |final $arrElemType[] $arr = null;
+         |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)}
+         |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr);
+       """.stripMargin
+    }
+
+    if (nullable) {
+      // Each child's code is nested inside the null check of the previous one, so the result
+      // stays null unless every nullable input is non-null.
+      val nullSafeEval =
+        startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) {
+          stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) {
+            stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) {
+              s"""
+                 |${ev.isNull} = false;
+                 |$resultCode
+               """.stripMargin
+            }
+          }
+        }
+      ev.copy(code =
+        s"""
+           |boolean ${ev.isNull} = true;
+           |$resultType ${ev.value} = null;
+           |$nullSafeEval
+         """.stripMargin)
+
+    } else {
+      ev.copy(code =
+        s"""
+           |boolean ${ev.isNull} = false;
+           |${startGen.code}
+           |${stopGen.code}
+           |${stepGen.code}
+           |$resultType ${ev.value} = null;
+           |$resultCode
+         """.stripMargin,
+        isNull = FalseLiteral)
+    }
+  }
+}
+
+object Sequence {
+
+ private type LessThanOrEqualFn = (Any, Any) => Boolean
+
+ /**
+  * Default step chooser: yields `one` when `start` is less than or equal to `stop`
+  * according to `lteq`, and the negation of `one` otherwise.
+  *
+  * @param lteq     "less than or equal" predicate over the boundary values
+  * @param stepType type attached to the generated step expression
+  * @param one      the positive unit step
+  */
+ private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) {
+   // The negative unit step, obtained by evaluating a unary minus over a literal of `one`.
+   private val negativeOne = UnaryMinus(Literal(one)).eval()
+
+   /** Interpreted path: picks the step for the given boundaries. */
+   def apply(start: Any, stop: Any): Any =
+     if (lteq(start, stop)) one else negativeOne
+
+   /** Codegen path: emits a ternary expression that selects between the two literal steps. */
+   def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = {
+     val forwardStep = Literal(one).genCode(ctx).value
+     val backwardStep = Literal(negativeOne).genCode(ctx).value
+     val ternary = s"${startGen.value} <= ${stopGen.value} ? $forwardStep : $backwardStep"
+     ExprCode.forNonNullValue(JavaCode.expression(ternary, stepType))
+   }
+ }
+
+ /** Contract shared by the type-specific sequence generators. */
+ private trait SequenceImpl {
+   /** Step to use when the expression was given no explicit step. */
+   val defaultStep: DefaultStep
+
+   /** Interpreted evaluation over already evaluated `start`, `stop` and `step` values. */
+   def eval(start: Any, stop: Any, step: Any): Any
+
+   /**
+    * Produces Java source that assigns the generated sequence to the array variable `arr`.
+    * `start`, `stop` and `step` are Java expressions or variable names, and `elemType` is the
+    * Java element type of `arr`.
+    */
+   def genCode(
+       ctx: CodegenContext,
+       start: String,
+       stop: String,
+       step: String,
+       arr: String,
+       elemType: String): String
+ }
+
+ /**
+  * Sequence generator for integral element types: element `k` is `start + step * k`.
+  *
+  * @param elemType the integral type whose ordering drives the default step
+  */
+ private class IntegralSequenceImpl[T: ClassTag]
+     (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl {
+
+   override val defaultStep: DefaultStep = new DefaultStep(
+     (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
+     elemType,
+     num.one)
+
+   /** Builds the whole array eagerly; the length is validated by `getSequenceLength`. */
+   override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
+     import num._
+
+     val head = input1.asInstanceOf[T]
+     val bound = input2.asInstanceOf[T]
+     val stride = input3.asInstanceOf[T]
+
+     val size: Int = getSequenceLength(head, bound, stride)
+     val out = new Array[T](size)
+     var k = 0
+     while (k < size) {
+       out(k) = head + stride * num.fromInt(k)
+       k += 1
+     }
+     out
+   }
+
+   /**
+    * Emits Java code that computes the length into a fresh variable, allocates `arr`, and
+    * fills it from the last index down to zero, reusing the length variable as the index.
+    */
+   override def genCode(
+       ctx: CodegenContext,
+       start: String,
+       stop: String,
+       step: String,
+       arr: String,
+       elemType: String): String = {
+     val counter = ctx.freshName("i")
+     val lengthCode = genSequenceLengthCode(ctx, start, stop, step, counter)
+     s"""
+        |$lengthCode
+        |$arr = new $elemType[$counter];
+        |while ($counter > 0) {
+        |  $counter--;
+        |  $arr[$counter] = ($elemType) ($start + $step * $counter);
+        |}
+      """.stripMargin
+   }
+ }
+
+ /**
+  * Sequence generator for temporal element types whose step is a calendar interval.
+  *
+  * Boundaries are multiplied by `scale` before interval arithmetic and divided by it again
+  * when stored. Steps without a month component are delegated to the integral implementation;
+  * steps with months are advanced one interval at a time via `timestampAddInterval`.
+  *
+  * @param dt       integral type backing the element values
+  * @param scale    factor converting an element value to the unit used for stepping
+  * @param fromLong converts a stepped-and-rescaled value back to the element type
+  * @param timeZone time zone passed to `timestampAddInterval`
+  */
+ private class TemporalSequenceImpl[T: ClassTag]
+     (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone)
+     (implicit num: Integral[T]) extends SequenceImpl {
+
+   override val defaultStep: DefaultStep = new DefaultStep(
+     (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
+     CalendarIntervalType,
+     new CalendarInterval(0, MICROS_PER_DAY))
+
+   private val backedSequenceImpl = new IntegralSequenceImpl[T](dt)
+   // Month length assumed when estimating the array size (28 days).
+   private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY
+
+   override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
+     val start = input1.asInstanceOf[T]
+     val stop = input2.asInstanceOf[T]
+     val step = input3.asInstanceOf[CalendarInterval]
+     val stepMonths = step.months
+     val stepMicros = step.microseconds
+
+     if (stepMonths == 0) {
+       // No month component: the step is a fixed amount, so plain integral stepping applies.
+       backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale))
+
+     } else {
+       // To estimate the resulted array length we need to make assumptions
+       // about a month length in microseconds
+       val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth
+       val startMicros: Long = num.toLong(start) * scale
+       val stopMicros: Long = num.toLong(stop) * scale
+       val maxEstimatedArrayLength =
+         getSequenceLength(startMicros, stopMicros, intervalStepInMicros)
+
+       // The direction comes from the step, not from the boundaries. Otherwise
+       // `start == stop` with a positive step is treated as descending, the loop never
+       // terminates and it writes past the one-element array.
+       val ascending = intervalStepInMicros > 0
+       val arr = new Array[T](maxEstimatedArrayLength)
+       var t = startMicros
+       var i = 0
+
+       // `stop` is inclusive in both directions. The previous XOR-based condition admitted
+       // `stop - 1` for descending sequences.
+       while (if (ascending) t <= stopMicros else t >= stopMicros) {
+         arr(i) = fromLong(t / scale)
+         t = timestampAddInterval(t, stepMonths, stepMicros, timeZone)
+         i += 1
+       }
+
+       // truncate array to the correct length
+       if (arr.length == i) arr else arr.slice(0, i)
+     }
+   }
+
+   /**
+    * Emits the Java counterpart of `eval`: integral stepping when the interval has no months,
+    * otherwise an estimate-allocate-fill-truncate loop using `timestampAddInterval`.
+    */
+   override def genCode(
+       ctx: CodegenContext,
+       start: String,
+       stop: String,
+       step: String,
+       arr: String,
+       elemType: String): String = {
+     val stepMonths = ctx.freshName("stepMonths")
+     val stepMicros = ctx.freshName("stepMicros")
+     val stepScaled = ctx.freshName("stepScaled")
+     val intervalInMicros = ctx.freshName("intervalInMicros")
+     val startMicros = ctx.freshName("startMicros")
+     val stopMicros = ctx.freshName("stopMicros")
+     val arrLength = ctx.freshName("arrLength")
+     val ascending = ctx.freshName("ascending")
+     val t = ctx.freshName("t")
+     val i = ctx.freshName("i")
+     val genTimeZone = ctx.freshName("timeZone")
+
+     val sequenceLengthCode =
+       s"""
+          |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L;
+          |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)}
+        """.stripMargin
+
+     val timestampAddIntervalCode =
+       s"""
+          |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
+          |  $t, $stepMonths, $stepMicros, $genTimeZone);
+        """.stripMargin
+
+     // The direction and loop condition mirror `eval`: direction from the step sign,
+     // inclusive comparison against stop.
+     s"""
+        |final int $stepMonths = $step.months;
+        |final long $stepMicros = $step.microseconds;
+        |
+        |if ($stepMonths == 0) {
+        |  final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L);
+        |  ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)};
+        |
+        |} else {
+        |  final long $startMicros = $start * ${scale}L;
+        |  final long $stopMicros = $stop * ${scale}L;
+        |
+        |  $sequenceLengthCode
+        |
+        |  final boolean $ascending = $intervalInMicros > 0;
+        |  final java.util.TimeZone $genTimeZone =
+        |    java.util.TimeZone.getTimeZone("${timeZone.getID}");
+        |
+        |  $arr = new $elemType[$arrLength];
+        |  long $t = $startMicros;
+        |  int $i = 0;
+        |
+        |  while ($ascending ? $t <= $stopMicros : $t >= $stopMicros) {
+        |    $arr[$i] = ($elemType) ($t / ${scale}L);
+        |    $timestampAddIntervalCode
+        |    $i += 1;
+        |  }
+        |
+        |  if ($arr.length > $i) {
+        |    $arr = java.util.Arrays.copyOf($arr, $i);
+        |  }
+        |}
+      """.stripMargin
+   }
+ }
+
+ /**
+  * Computes the number of elements in the sequence `start, start + step, ...` up to and
+  * including `stop`.
+  *
+  * @param start first element
+  * @param stop  inclusive bound
+  * @param step  increment; must point from `start` towards `stop`, or be zero when
+  *              `start == stop`
+  * @return the element count, always at least 1
+  * @throws IllegalArgumentException if the step direction contradicts the boundaries, or if
+  *                                  the count exceeds `Int.MaxValue`
+  */
+ private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = {
+   import num._
+   require(
+     (step > num.zero && start <= stop)
+       || (step < num.zero && start >= stop)
+       || (step == num.zero && start == stop),
+     s"Illegal sequence boundaries: $start to $stop by $step")
+
+   // BigInt arithmetic: with Long inputs `stop - start` (and Long.MinValue / -1) can overflow,
+   // which previously produced a bogus small or negative length instead of an error.
+   val len: BigInt =
+     if (start == stop) BigInt(1)
+     else (BigInt(stop.toLong) - BigInt(start.toLong)) / BigInt(step.toLong) + BigInt(1)
+
+   require(len <= BigInt(Int.MaxValue), s"Too long sequence: $len. Should be <= ${Int.MaxValue}")
+
+   len.toInt
+ }
+
+ private def genSequenceLengthCode(
+ ctx: CodegenContext,
+ start: String,
+ stop: String,
+ step: String,
+ len: String): String = {
+ val longLen = ctx.freshName("longLen")
+ s"""
+ |if (!($step > 0 && $start <= $stop||
+ | $step < 0 && $start >= $stop||
+ | $step == 0 && $start == $stop)) {
--- End diff --
done
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]