GitHub user mgaido91 commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21040#discussion_r181338128
  
    --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---
    @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: 
Expression)
     
       override def prettyName: String = "array_contains"
     }
    +
    +
    +/**
    + * Slices an array according to the requested start index and length
    + */
    +// scalastyle:off line.size.limit
    +@ExpressionDescription(
    +  usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or 
starting from the end if start is negative) with the specified length.",
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
    +       [2,3]
    +      > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
    +       [3,4]
    +  """, since = "2.4.0")
    +// scalastyle:on line.size.limit
    +case class Slice(x: Expression, start: Expression, length: Expression)
    +  extends TernaryExpression with ImplicitCastInputTypes {
    +
    +  override def dataType: DataType = x.dataType
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, 
IntegerType, IntegerType)
    +
    +  override def nullable: Boolean = children.exists(_.nullable)
    +
    +  override def foldable: Boolean = children.forall(_.foldable)
    +
    +  override def children: Seq[Expression] = Seq(x, start, length)
    +
    +  override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any 
= {
    +    val startInt = startVal.asInstanceOf[Int]
    +    val lengthInt = lengthVal.asInstanceOf[Int]
    +    val arr = xVal.asInstanceOf[ArrayData]
    +    val startIndex = if (startInt == 0) {
    +      throw new RuntimeException(
    +        s"Unexpected value for start in function $prettyName:  SQL array 
indices start at 1.")
    +    } else if (startInt < 0) {
    +      startInt + arr.numElements()
    +    } else {
    +      startInt - 1
    +    }
    +    if (lengthInt < 0) {
    +      throw new RuntimeException(s"Unexpected value for length in function 
$prettyName: " +
    +        s"length must be greater than or equal to 0.")
    +    }
    +    // this can happen if start is negative and its absolute value is 
greater than the
    +    // number of elements in the array
    +    if (startIndex < 0) {
    +      return new GenericArrayData(Array.empty[AnyRef])
    +    }
    +    val elementType = x.dataType.asInstanceOf[ArrayType].elementType
    +    val data = arr.toArray[AnyRef](elementType)
    +    new GenericArrayData(data.slice(startIndex, startIndex + lengthInt))
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val elementType = x.dataType.asInstanceOf[ArrayType].elementType
    +    nullSafeCodeGen(ctx, ev, (x, start, length) => {
    +      val arrayClass = classOf[GenericArrayData].getName
    +      val values = ctx.freshName("values")
    +      val i = ctx.freshName("i")
    +      val startIdx = ctx.freshName("startIdx")
    +      val resLength = ctx.freshName("resLength")
    +      val defaultIntValue = 
CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
    +      s"""
    +         |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
    +         |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
    +         |if ($start == 0) {
    +         |  throw new RuntimeException("Unexpected value for start in 
function $prettyName: "
    +         |    + "SQL array indices start at 1.");
    +         |} else if ($start < 0) {
    +         |  $startIdx = $start + $x.numElements();
    +         |} else {
    +         |  // arrays in SQL are 1-based instead of 0-based
    +         |  $startIdx = $start - 1;
    +         |}
    +         |if ($length < 0) {
    +         |  throw new RuntimeException("Unexpected value for length in 
function $prettyName: "
    +         |    + "length must be greater than or equal to 0.");
    +         |} else if ($length > $x.numElements() - $startIdx) {
    +         |  $resLength = $x.numElements() - $startIdx;
    +         |} else {
    +         |  $resLength = $length;
    +         |}
    +         |Object[] $values;
    +         |if ($startIdx < 0) {
    +         |  $values = new Object[0];
    +         |} else {
    +         |  $values = new Object[$resLength];
    +         |  for (int $i = 0; $i < $resLength; $i ++) {
    +         |    $values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i 
+ $startIdx")};
    --- End diff --
    
    For consistency, I was aiming to match the `CreateArray` operator and the code generated in `GenerateSafeProjection`.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to