Repository: spark
Updated Branches:
  refs/heads/master f06528015 -> e35ad3cad


[SPARK-23930][SQL] Add slice function

## What changes were proposed in this pull request?

This PR adds the `slice` function. Its behavior is based on Presto's
`slice` function.

The function slices an array according to the requested start index and length.

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <[email protected]>

Closes #21040 from mgaido91/SPARK-23930.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/e35ad3ca
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/e35ad3ca
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/e35ad3ca

Branch: refs/heads/master
Commit: e35ad3caddeaa4b0d4c8524dcfb9e9f56dc7fe3d
Parents: f065280
Author: Marco Gaido <[email protected]>
Authored: Mon May 7 16:57:37 2018 +0900
Committer: Takuya UESHIN <[email protected]>
Committed: Mon May 7 16:57:37 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/functions.py                 |  13 ++
 .../catalyst/analysis/FunctionRegistry.scala    |   1 +
 .../expressions/codegen/CodeGenerator.scala     |  34 ++++
 .../expressions/collectionOperations.scala      | 163 ++++++++++++++-----
 .../CollectionExpressionsSuite.scala            |  28 ++++
 .../expressions/ExpressionEvalHelper.scala      |   6 +
 .../expressions/ObjectExpressionsSuite.scala    |   1 -
 .../scala/org/apache/spark/sql/functions.scala  |  10 ++
 .../spark/sql/DataFrameFunctionsSuite.scala     |  16 ++
 9 files changed, 233 insertions(+), 39 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index bd55b5f..ac3c797 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1834,6 +1834,21 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
 
 
+@since(2.4)
+def slice(x, start, length):
+    """
+    Collection function: returns an array containing all the elements in `x` from index `start`
+    (array indices start at 1, or starting from the end if `start` is negative) with the specified
+    `length`.
+
+    >>> df = spark.createDataFrame([([1, 2, 3],), ([4, 5],)], ['x'])
+    >>> df.select(slice(df.x, 2, 2).alias("sliced")).collect()
+    [Row(sliced=[2, 3]), Row(sliced=[5])]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.slice(_to_java_column(x), start, length))
+
+
 @ignore_unicode_prefix
 @since(2.4)
 def array_join(col, delimiter, null_replacement=None):

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 01776b8..87b0911 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -410,6 +410,7 @@ object FunctionRegistry {
     expression[MapKeys]("map_keys"),
     expression[MapValues]("map_values"),
     expression[Size]("size"),
+    expression[Slice]("slice"),
     expression[Size]("cardinality"),
     expression[SortArray]("sort_array"),
     expression[ArrayMin]("array_min"),

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index cf0a91f..4dda525 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types._
 import org.apache.spark.util.{ParentClassLoader, Utils}
 
@@ -731,6 +732,39 @@ class CodegenContext {
   }
 
   /**
+   * Generates code allocating an [[UnsafeArrayData]] named `arrayName`, or throwing if too large.
+   *
+   * @param arrayName name of the array variable declared by the generated code
+   * @param numElements code representing the number of elements the array should contain
+   * @param elementType data type of the elements; its `defaultSize` is used as the element size
+   * @param additionalErrorMessage text appended to the error message, verbatim in a Java literal
+   */
+  def createUnsafeArray(
+      arrayName: String,
+      numElements: String,
+      elementType: DataType,
+      additionalErrorMessage: String): String = {
+    val arraySize = freshName("size")
+    val arrayBytes = freshName("arrayBytes")
+
+    s"""
+       |long $arraySize = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+       |  $numElements,
+       |  ${elementType.defaultSize});
+       |if ($arraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
+       |  throw new RuntimeException("Unsuccessful try to create array with " + $arraySize +
+       |    " bytes of data due to exceeding the limit " +
+       |    "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH} bytes for UnsafeArrayData." +
+       |    "$additionalErrorMessage");
+       |}
+       |byte[] $arrayBytes = new byte[(int)$arraySize];
+       |UnsafeArrayData $arrayName = new UnsafeArrayData();
+       |Platform.putLong($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, $numElements);
+       |$arrayName.pointTo($arrayBytes, ${Platform.BYTE_ARRAY_OFFSET}, (int)$arraySize);
+      """.stripMargin
+  }
+
+  /**
    * Generates code to do null safe execution, i.e. only execute the code when 
the input is not
    * null by adding null check if necessary.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 23c09bc..12b9ab2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -24,7 +24,6 @@ import 
org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, 
MapData, TypeUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.array.ByteArrayMethods
 import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 
@@ -531,6 +530,129 @@ case class ArrayContains(left: Expression, right: 
Expression)
 }
 
 /**
+ * Slices an array according to the requested (1-based) start index and length.
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+  usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (array indices start at 1, or starting from the end if start is negative) with the specified length.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
+       [2,3]
+      > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
+       [3,4]
+  """, since = "2.4.0")
+// scalastyle:on line.size.limit
+case class Slice(x: Expression, start: Expression, length: Expression)
+  extends TernaryExpression with ImplicitCastInputTypes {
+
+  override def dataType: DataType = x.dataType
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)
+
+  override def children: Seq[Expression] = Seq(x, start, length)
+
+  lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType
+
+  override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
+    val startInt = startVal.asInstanceOf[Int]
+    val lengthInt = lengthVal.asInstanceOf[Int]
+    val arr = xVal.asInstanceOf[ArrayData]
+    val startIndex = if (startInt == 0) {
+      throw new RuntimeException(
+        s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
+    } else if (startInt < 0) {
+      startInt + arr.numElements()
+    } else {
+      startInt - 1
+    }
+    if (lengthInt < 0) {
+      throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
+        "length must be greater than or equal to 0.")
+    }
+    // startIndex can be negative if start is negative and its absolute value is greater than the
+    // number of elements in the array. The end index is computed in Long to avoid Int overflow.
+    if (startIndex < 0 || startIndex >= arr.numElements()) {
+      return new GenericArrayData(Array.empty[AnyRef])
+    }
+    val endIndex = math.min(startIndex.toLong + lengthInt, arr.numElements()).toInt
+    new GenericArrayData(arr.toSeq[AnyRef](elementType).slice(startIndex, endIndex))
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    nullSafeCodeGen(ctx, ev, (x, start, length) => {
+      val startIdx = ctx.freshName("startIdx")
+      val resLength = ctx.freshName("resLength")
+      val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
+      s"""
+         |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
+         |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
+         |if ($start == 0) {
+         |  throw new RuntimeException("Unexpected value for start in function $prettyName: "
+         |    + "SQL array indices start at 1.");
+         |} else if ($start < 0) {
+         |  $startIdx = $start + $x.numElements();
+         |} else {
+         |  // arrays in SQL are 1-based instead of 0-based
+         |  $startIdx = $start - 1;
+         |}
+         |if ($length < 0) {
+         |  throw new RuntimeException("Unexpected value for length in function $prettyName: "
+         |    + "length must be greater than or equal to 0.");
+         |} else if ($length > $x.numElements() - $startIdx) {
+         |  $resLength = $x.numElements() - $startIdx;
+         |} else {
+         |  $resLength = $length;
+         |}
+         |${genCodeForResult(ctx, ev, x, startIdx, resLength)}
+       """.stripMargin
+    })
+  }
+
+  def genCodeForResult(
+      ctx: CodegenContext,
+      ev: ExprCode,
+      inputArray: String,
+      startIdx: String,
+      resLength: String): String = {
+    val values = ctx.freshName("values")
+    val i = ctx.freshName("i")
+    val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx")
+    if (!CodeGenerator.isPrimitiveType(elementType)) {
+      val arrayClass = classOf[GenericArrayData].getName
+      s"""
+         |Object[] $values;
+         |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+         |  $values = new Object[0];
+         |} else {
+         |  $values = new Object[$resLength];
+         |  for (int $i = 0; $i < $resLength; $i ++) {
+         |    $values[$i] = $getValue;
+         |  }
+         |}
+         |${ev.value} = new $arrayClass($values);
+       """.stripMargin
+    } else {
+      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+      s"""
+         |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
+         |  $resLength = 0;
+         |}
+         |${ctx.createUnsafeArray(values, resLength, elementType, s" $prettyName failed.")}
+         |for (int $i = 0; $i < $resLength; $i ++) {
+         |  if ($inputArray.isNullAt($i + $startIdx)) {
+         |    $values.setNullAt($i);
+         |  } else {
+         |    $values.set$primitiveValueTypeName($i, $getValue);
+         |  }
+         |}
+         |${ev.value} = $values;
+       """.stripMargin
+    }
+  }
+}
+
+/**
  * Creates a String containing all the elements of the input array separated 
by the delimiter.
  */
 @ExpressionDescription(
@@ -1127,24 +1249,11 @@ case class Concat(children: Seq[Expression]) extends 
Expression {
   }
 
   private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: 
DataType): String = {
-    val arrayName = ctx.freshName("array")
-    val arraySizeName = ctx.freshName("size")
     val counter = ctx.freshName("counter")
     val arrayData = ctx.freshName("arrayData")
 
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
 
-    val unsafeArraySizeInBytes = s"""
-      |long $arraySizeName = 
UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
-      |  $numElemName,
-      |  ${elementType.defaultSize});
-      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
-      |  throw new RuntimeException("Unsuccessful try to concat arrays with " 
+ $arraySizeName +
-      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" 
+
-      |    " for UnsafeArrayData.");
-      |}
-      """.stripMargin
-    val baseOffset = Platform.BYTE_ARRAY_OFFSET
     val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
 
     s"""
@@ -1152,11 +1261,7 @@ case class Concat(children: Seq[Expression]) extends 
Expression {
        |  public ArrayData concat($javaType[] args) {
        |    ${nullArgumentProtection()}
        |    $numElemCode
-       |    $unsafeArraySizeInBytes
-       |    byte[] $arrayName = new byte[(int)$arraySizeName];
-       |    UnsafeArrayData $arrayData = new UnsafeArrayData();
-       |    Platform.putLong($arrayName, $baseOffset, $numElemName);
-       |    $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+       |    ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" 
$prettyName failed.")}
        |    int $counter = 0;
        |    for (int y = 0; y < ${children.length}; y++) {
        |      for (int z = 0; z < args[y].numElements(); z++) {
@@ -1308,34 +1413,16 @@ case class Flatten(child: Expression) extends 
UnaryExpression {
       ctx: CodegenContext,
       childVariableName: String,
       arrayDataName: String): String = {
-    val arrayName = ctx.freshName("array")
-    val arraySizeName = ctx.freshName("size")
     val counter = ctx.freshName("counter")
     val tempArrayDataName = ctx.freshName("tempArrayData")
 
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, 
childVariableName)
 
-    val unsafeArraySizeInBytes = s"""
-      |long $arraySizeName = 
UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
-      |  $numElemName,
-      |  ${elementType.defaultSize});
-      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
-      |  throw new RuntimeException("Unsuccessful try to flatten an array of 
arrays with " +
-      |    $arraySizeName + " bytes of data due to exceeding the limit 
$MAX_ARRAY_LENGTH" +
-      |    " bytes for UnsafeArrayData.");
-      |}
-      """.stripMargin
-    val baseOffset = Platform.BYTE_ARRAY_OFFSET
-
     val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
 
     s"""
     |$numElemCode
-    |$unsafeArraySizeInBytes
-    |byte[] $arrayName = new byte[(int)$arraySizeName];
-    |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
-    |Platform.putLong($arrayName, $baseOffset, $numElemName);
-    |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+    |${ctx.createUnsafeArray(tempArrayDataName, numElemName, elementType, s" 
$prettyName failed.")}
     |int $counter = 0;
     |for (int k = 0; k < $childVariableName.numElements(); k++) {
     |  ArrayData arr = $childVariableName.getArray(k);

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 749374f..a2851d0 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -136,6 +136,36 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
   }
 
+  test("Slice") {
+    val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
+    val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType))
+    val a3 = Literal.create(Seq(1, 2, null, 4), ArrayType(IntegerType))
+
+    checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2))
+    checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5))
+    checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6))
+    checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6))
+    checkEvaluation(Slice(a0, Literal(-6), Literal(2)), Seq(1, 2))
+    checkEvaluation(Slice(a0, Literal(1), Literal(0)), Seq.empty[Int])
+    checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)),
+      "Unexpected value for length")
+    checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)),
+      "Unexpected value for start")
+    checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int])
+    checkEvaluation(Slice(a1, Literal(-20), Literal(1)), Seq.empty[String])
+    checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null)
+    checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null)
+    checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)),
+      null)
+
+    checkEvaluation(Slice(a1, Literal(1), Literal(2)), Seq("a", "b"))
+    checkEvaluation(Slice(a2, Literal(1), Literal(2)), Seq("", null))
+    checkEvaluation(Slice(a0, Literal(10), Literal(1)), Seq.empty[Int])
+    checkEvaluation(Slice(a1, Literal(10), Literal(1)), Seq.empty[String])
+    checkEvaluation(Slice(a3, Literal(2), Literal(3)), Seq(2, null, 4))
+  }
+
   test("ArrayJoin") {
     def testArrays(
         arrays: Seq[Expression],

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index b4bf6d7..a22e9d4 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -106,6 +106,12 @@ trait ExpressionEvalHelper extends 
GeneratorDrivenPropertyChecks {
 
   protected def checkExceptionInExpression[T <: Throwable : ClassTag](
       expression: => Expression,
+      expectedErrMsg: String): Unit = {
+    checkExceptionInExpression[T](expression, InternalRow.empty, 
expectedErrMsg)
+  }
+
+  protected def checkExceptionInExpression[T <: Throwable : ClassTag](
+      expression: => Expression,
       inputRow: InternalRow,
       expectedErrMsg: String): Unit = {
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 730b36c..77ca640 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -223,7 +223,6 @@ class ObjectExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
       Literal.fromObject(new java.util.LinkedList[Int]),
       Map("nonexisting" -> Literal(1)))
     checkExceptionInExpression[Exception](initializeWithNonexistingMethod,
-      InternalRow.fromSeq(Seq()),
       """A method named "nonexisting" is not declared in any enclosing class 
""" +
         "nor any supertype")
 

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 10b6dcc..8f9e4ae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -3040,6 +3040,20 @@ object functions {
   }
 
   /**
+   * Returns an array containing all the elements in `x` from index `start` (array indices start
+   * at 1, or starting from the end if `start` is negative) with the specified `length`.
+   *
+   * @param x the array column to slice
+   * @param start the first position to keep (1-based; negative counts from the end; 0 is invalid)
+   * @param length the maximum number of elements to keep; must be non-negative
+   * @group collection_funcs
+   * @since 2.4.0
+   */
+  def slice(x: Column, start: Int, length: Int): Column = withExpr {
+    Slice(x = x.expr, start = Literal(start), length = Literal(length))
+  }
+
+  /**
    * Concatenates the elements of `column` using the `delimiter`. Null values 
are replaced with
    * `nullReplacement`.
    * @group collection_funcs

http://git-wip-us.apache.org/repos/asf/spark/blob/e35ad3ca/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index ae21cbc..ecce06f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -442,6 +442,26 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }
 
+  test("slice function") {
+    val df = Seq(
+      Seq(1, 2, 3),
+      Seq(4, 5)
+    ).toDF("x")
+
+    val answer = Seq(Row(Seq(2, 3)), Row(Seq(5)))
+
+    checkAnswer(df.select(slice(df("x"), 2, 2)), answer)
+    checkAnswer(df.selectExpr("slice(x, 2, 2)"), answer)
+
+    val answerNegative = Seq(Row(Seq(3)), Row(Seq(5)))
+    checkAnswer(df.select(slice(df("x"), -1, 1)), answerNegative)
+    checkAnswer(df.selectExpr("slice(x, -1, 1)"), answerNegative)
+
+    val answerOutOfRange = Seq(Row(Seq.empty[Int]), Row(Seq.empty[Int]))
+    checkAnswer(df.select(slice(df("x"), 5, 1)), answerOutOfRange)
+    checkAnswer(df.selectExpr("slice(x, -5, 1)"), answerOutOfRange)
+  }
+
   test("array_join function") {
     val df = Seq(
       (Seq[String]("a", "b"), ","),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to