[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user asfgit closed the pull request at: https://github.com/apache/spark/pull/21040 --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r186139938 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +"length must be greater than or equal to 0.") +} +// startIndex can be negative if start is negative and its absolute value is greater than the +// 
number of elements in the array +if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val data = arr.toSeq[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin +}) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { +val values = ctx.freshName("values") +val i = ctx.freshName("i") +val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") +if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) 
{ + |$values[$i] = $getValue; + | } + |} + |${ev.value} =
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r186133103 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +"length must be greater than or equal to 0.") +} +// startIndex can be negative if start is negative and its absolute value is greater than the +// 
number of elements in the array +if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val data = arr.toSeq[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin +}) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { +val values = ctx.freshName("values") +val i = ctx.freshName("i") +val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") +if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) 
{ + |$values[$i] = $getValue; + | } + |} + |${ev.value} = ne
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r186104739 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +"length must be greater than or equal to 0.") +} +// startIndex can be negative if start is negative and its absolute value is greater than the +// 
number of elements in the array +if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val data = arr.toSeq[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin +}) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { +val values = ctx.freshName("values") +val i = ctx.freshName("i") +val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") +if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) 
{ + |$values[$i] = $getValue; + | } + |} + |${ev.value} =
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r186103689 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +"length must be greater than or equal to 0.") +} +// startIndex can be negative if start is negative and its absolute value is greater than the +// 
number of elements in the array +if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val data = arr.toSeq[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin +}) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { +val values = ctx.freshName("values") +val i = ctx.freshName("i") +val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") +if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) 
{ + |$values[$i] = $getValue; + | } + |} + |${ev.value} =
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r186097518 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +"length must be greater than or equal to 0.") +} +// startIndex can be negative if start is negative and its absolute value is greater than the +// 
number of elements in the array +if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val data = arr.toSeq[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin +}) + } + + def genCodeForResult( + ctx: CodegenContext, + ev: ExprCode, + inputArray: String, + startIdx: String, + resLength: String): String = { +val values = ctx.freshName("values") +val i = ctx.freshName("i") +val getValue = CodeGenerator.getValue(inputArray, elementType, s"$i + $startIdx") +if (!CodeGenerator.isPrimitiveType(elementType)) { + val arrayClass = classOf[GenericArrayData].getName + s""" + |Object[] $values; + |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) 
{ + |$values[$i] = $getValue; + | } + |} + |${ev.value} = ne
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r184914154 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -378,6 +378,138 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def children: Seq[Expression] = Seq(x, start, length) + + lazy val elementType: DataType = x.dataType.asInstanceOf[ArrayType].elementType + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +"length must be greater than or equal to 0.") +} +// startIndex can be negative if start is negative and its absolute value is greater than the +// 
number of elements in the array +if (startIndex < 0 || startIndex >= arr.numElements()) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val data = arr.toSeq[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |${genCodeForResult(ctx, ev, x, startIdx, resLength)} + """.stripMargin +}) + } + + def genCodeForResult( +ctx: CodegenContext, +ev: ExprCode, +inputArray: String, +startIdx: String, +resLength: String): String = { --- End diff -- nit: indent --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r184666357 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- You are right, I am not sure why I missed it...maybe I checked outdated code. Sorry, I am fixing it, thanks. 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r184257279 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala --- @@ -105,4 +105,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Slice") { +val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) +val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) +val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + +checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) +checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) +checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) +checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) +checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") +checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") +checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) +checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) +checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) +checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) --- End diff -- And also can you add a case for nullable primitive array like `Slice(Seq(1, 2, null, 4), 2, 3)`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r184257274 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- I might miss something, but seems like `CreateArray` is using different ways to codegen for primitive arrays and the others, and I guess `GenerateSafeProjection` is using 
`Object[]` on purpose to create `GenericArrayData` to be "safe" (avoid using `UnsafeXxx`). I think we should modify this codegen to avoid boxing. WDYT? Bt
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r183100357 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala --- @@ -105,4 +105,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Slice") { +val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) +val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) +val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + +checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) +checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) +checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) +checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) +checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") +checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") +checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) +checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) +checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) +checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) --- End diff -- added --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r182703273 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") --- End diff -- nit: remove an extra space between `$prettyName:` and `SQL`. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r182706982 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala --- @@ -102,6 +102,12 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } } + protected def checkExceptionInExpression[T <: Throwable : ClassTag]( + expression: Expression, --- End diff -- `expression: => Expression` to be consistent with the overloaded one, just in case? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r182701319 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) --- End diff -- We don't need `nullable` and `foldable` here because these are the same as defined in `TernaryExpression`. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r182701643 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") --- End diff -- nit: unnecessary `s`. 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r182701040 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", --- End diff -- `_FUNC_(x, start, length)` instead of `_FUNC_(a1, a2)`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r182702693 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { --- End diff -- We should also skip when `startIndex >= arr.numElements()` to avoid the unnecessary `arr.toArray` conversion? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r182705280 --- Diff: sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala --- @@ -105,4 +105,28 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(ArrayContains(a3, Literal("")), null) checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null) } + + test("Slice") { +val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType)) +val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType)) +val a2 = Literal.create(Seq[String]("", null, "a", "b"), ArrayType(StringType)) + +checkEvaluation(Slice(a0, Literal(1), Literal(2)), Seq(1, 2)) +checkEvaluation(Slice(a0, Literal(-3), Literal(2)), Seq(4, 5)) +checkEvaluation(Slice(a0, Literal(4), Literal(10)), Seq(4, 5, 6)) +checkEvaluation(Slice(a0, Literal(-1), Literal(2)), Seq(6)) +checkExceptionInExpression[RuntimeException](Slice(a0, Literal(1), Literal(-1)), + "Unexpected value for length") +checkExceptionInExpression[RuntimeException](Slice(a0, Literal(0), Literal(1)), + "Unexpected value for start") +checkEvaluation(Slice(a0, Literal(-20), Literal(1)), Seq.empty[Int]) +checkEvaluation(Slice(a0, Literal.create(null, IntegerType), Literal(2)), null) +checkEvaluation(Slice(a0, Literal(2), Literal.create(null, IntegerType)), null) +checkEvaluation(Slice(Literal.create(null, ArrayType(IntegerType)), Literal(1), Literal(2)), + null) --- End diff -- Can you add a case for something like `Slice(a0, Literal(10), Literal(1))`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r181338128 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- My target of coherency was the `CreateArray` operator and the code generated in `GenerateSafeProjection`. 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r181313290 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- For the future, I agree that this is the right way to generate Java code since we can avoid boxing. On the other hand, you are proposing to postpone specialization. 
In `eval` and generated code, `GenericArrayData` is generated by using `Object[]`. I may misunderstand `for coherency` since I may not find the target of the coh
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r181065461 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) --- End diff -- I think #20984 should be merged soon. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r181045317 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) --- End diff -- I think it would be good since we can avoid the whole array copy if that PR is merged in the near future. @viirya What do you think? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r180997886 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) --- End diff -- shall we wait for that PR to get in? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r180997042 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- I think it can be helpful. Moreover, this is the way also `CreateArray` and `GenerateSafeProjection` work, so for coherency I think this is the right thing to do. 
What do you think? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r180834827 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- I see. If we postpone specialization, is it necessary to generate Java code for now? The generated code seems to do the same thing in `nullSafeEval`. WDYT? 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r180834135 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) --- End diff -- This PR https://github.com/apache/spark/pull/20984 can make `slice` better. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user mgaido91 commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r180773760 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- I thought about that too, but I am not sure there is a better solution: this approach is used both in `CreateArray` and `GenerateSafeProjection`. 
And there is a TODO for specialized versions of `GenericArrayData` able to deal with primitive types without boxing. Probably we can try and fix this TODO in another PR/JIRA. What do
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r180770274 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { + | throw new RuntimeException("Unexpected value for length in function $prettyName: " + |+ "length must be greater than or equal to 0."); + |} else if ($length > $x.numElements() - $startIdx) { + | $resLength = $x.numElements() - $startIdx; + |} else { + | $resLength = $length; + |} + |Object[] $values; + |if ($startIdx < 0) { + | $values = new Object[0]; + |} else { + | $values = new Object[$resLength]; + | for (int $i = 0; $i < $resLength; $i ++) { + |$values[$i] = ${CodeGenerator.getValue(x, elementType, s"$i + $startIdx")}; --- End diff -- May this assignment cause performance degradation due to boxing if array element type is primitive (e.g. `float`)? 
--- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21040#discussion_r180766152 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -287,3 +287,101 @@ case class ArrayContains(left: Expression, right: Expression) override def prettyName: String = "array_contains" } + + +/** + * Slices an array according to the requested start index and length + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a1, a2) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.", + examples = """ +Examples: + > SELECT _FUNC_(array(1, 2, 3, 4), 2, 2); + [2,3] + > SELECT _FUNC_(array(1, 2, 3, 4), -2, 2); + [3,4] + """, since = "2.4.0") +// scalastyle:on line.size.limit +case class Slice(x: Expression, start: Expression, length: Expression) + extends TernaryExpression with ImplicitCastInputTypes { + + override def dataType: DataType = x.dataType + + override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType) + + override def nullable: Boolean = children.exists(_.nullable) + + override def foldable: Boolean = children.forall(_.foldable) + + override def children: Seq[Expression] = Seq(x, start, length) + + override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = { +val startInt = startVal.asInstanceOf[Int] +val lengthInt = lengthVal.asInstanceOf[Int] +val arr = xVal.asInstanceOf[ArrayData] +val startIndex = if (startInt == 0) { + throw new RuntimeException( +s"Unexpected value for start in function $prettyName: SQL array indices start at 1.") +} else if (startInt < 0) { + startInt + arr.numElements() +} else { + startInt - 1 +} +if (lengthInt < 0) { + throw new RuntimeException(s"Unexpected value for length in function $prettyName: " + +s"length must be greater than or equal to 0.") +} +// this can happen if start is negative and its absolute 
value is greater than the +// number of elements in the array +if (startIndex < 0) { + return new GenericArrayData(Array.empty[AnyRef]) +} +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +val data = arr.toArray[AnyRef](elementType) +new GenericArrayData(data.slice(startIndex, startIndex + lengthInt)) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { +val elementType = x.dataType.asInstanceOf[ArrayType].elementType +nullSafeCodeGen(ctx, ev, (x, start, length) => { + val arrayClass = classOf[GenericArrayData].getName + val values = ctx.freshName("values") + val i = ctx.freshName("i") + val startIdx = ctx.freshName("startIdx") + val resLength = ctx.freshName("resLength") + val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false) + s""" + |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue; + |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue; + |if ($start == 0) { + | throw new RuntimeException("Unexpected value for start in function $prettyName: " + |+ "SQL array indices start at 1."); + |} else if ($start < 0) { + | $startIdx = $start + $x.numElements(); + |} else { + | // arrays in SQL are 1-based instead of 0-based + | $startIdx = $start - 1; + |} + |if ($length < 0) { --- End diff -- `$lengthInt` ? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #21040: [SPARK-23930][SQL] Add slice function
GitHub user mgaido91 opened a pull request: https://github.com/apache/spark/pull/21040 [SPARK-23930][SQL] Add slice function ## What changes were proposed in this pull request? This PR adds the `slice` function. The behavior of the function is based on Presto's. The function slices an array according to the requested start index and length. ## How was this patch tested? added UTs You can merge this pull request into a Git repository by running: $ git pull https://github.com/mgaido91/spark SPARK-23930 Alternatively you can review and apply these changes as the patch at: https://github.com/apache/spark/pull/21040.patch To close this pull request, make a commit to your master/trunk branch with (at least) the following in the commit message: This closes #21040 commit 5cbbf7afb164d090bfe5730380a2fbe0a18146c2 Author: Marco Gaido Date: 2018-04-10T13:49:53Z [SPARK-23930][SQL] Add slice function --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org