Repository: spark
Updated Branches:
  refs/heads/branch-2.4 e42546259 -> 604828eda
[SPARK-25469][SQL] Eval methods of Concat, Reverse and ElementAt should use pattern matching only once

## What changes were proposed in this pull request?

The PR proposes to avoid the use of pattern matching on each call of the ```eval``` method within:
- ```Concat```
- ```Reverse```
- ```ElementAt```

## How was this patch tested?

Ran the existing tests for the ```Concat```, ```Reverse``` and ```ElementAt``` expression classes.

Closes #22471 from mn-mikke/SPARK-25470.

Authored-by: Marek Novotny <mn.mi...@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
(cherry picked from commit 2c9d8f56c71093faf152ca7136c5fcc4a7b2a95f)
Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/604828ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/604828ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/604828ed

Branch: refs/heads/branch-2.4
Commit: 604828eda0930b933be39d5db7bdb1b29d499f32
Parents: e425462
Author: Marek Novotny <mn.mi...@gmail.com>
Authored: Fri Sep 21 18:16:54 2018 +0900
Committer: Takeshi Yamamuro <yamam...@apache.org>
Committed: Fri Sep 21 18:30:32 2018 +0900

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala | 81 ++++++++++++--------
 1 file changed, 48 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/604828ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e23ebef..161adc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1268,11 +1268,15 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
 
   override def dataType: DataType = child.dataType
 
-  @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+  override def nullSafeEval(input: Any): Any = doReverse(input)
 
-  override def nullSafeEval(input: Any): Any = input match {
-    case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
-    case s: UTF8String => s.reverse()
+  @transient private lazy val doReverse: Any => Any = dataType match {
+    case ArrayType(elementType, _) =>
+      input => {
+        val arrayData = input.asInstanceOf[ArrayData]
+        new GenericArrayData(arrayData.toObjectArray(elementType).reverse)
+      }
+    case StringType => _.asInstanceOf[UTF8String].reverse()
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -1294,6 +1298,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
     val i = ctx.freshName("i")
     val j = ctx.freshName("j")
 
+    val elementType = dataType.asInstanceOf[ArrayType].elementType
     val initialization = CodeGenerator.createArrayData(
       arrayData, elementType, numElements, s" $prettyName failed.")
     val assignment = CodeGenerator.createArrayAssignment(
@@ -2164,9 +2169,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
 
   override def nullable: Boolean = true
 
-  override def nullSafeEval(value: Any, ordinal: Any): Any = {
-    left.dataType match {
-      case _: ArrayType =>
+  override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
+
+  @transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match {
+    case _: ArrayType =>
+      (value, ordinal) => {
         val array = value.asInstanceOf[ArrayData]
         val index = ordinal.asInstanceOf[Int]
         if (array.numElements() < math.abs(index)) {
@@ -2185,9 +2192,9 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
             array.get(idx, dataType)
           }
         }
-      case _: MapType =>
-        getValueEval(value, ordinal, mapKeyType, ordering)
-    }
+    }
+    case _: MapType =>
+      (value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -2278,33 +2285,41 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
 
   override def foldable: Boolean = children.forall(_.foldable)
 
-  override def eval(input: InternalRow): Any = dataType match {
+  override def eval(input: InternalRow): Any = doConcat(input)
+
+  @transient private lazy val doConcat: InternalRow => Any = dataType match {
     case BinaryType =>
-      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
-      ByteArray.concat(inputs: _*)
+      input => {
+        val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+        ByteArray.concat(inputs: _*)
+      }
     case StringType =>
-      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-      UTF8String.concat(inputs : _*)
+      input => {
+        val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+        UTF8String.concat(inputs: _*)
+      }
     case ArrayType(elementType, _) =>
-      val inputs = children.toStream.map(_.eval(input))
-      if (inputs.contains(null)) {
-        null
-      } else {
-        val arrayData = inputs.map(_.asInstanceOf[ArrayData])
-        val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
-        if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-          throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
-            " elements due to exceeding the array size limit " +
-            ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
-        }
-        val finalData = new Array[AnyRef](numberOfElements.toInt)
-        var position = 0
-        for(ad <- arrayData) {
-          val arr = ad.toObjectArray(elementType)
-          Array.copy(arr, 0, finalData, position, arr.length)
-          position += arr.length
+      input => {
+        val inputs = children.toStream.map(_.eval(input))
+        if (inputs.contains(null)) {
+          null
+        } else {
+          val arrayData = inputs.map(_.asInstanceOf[ArrayData])
+          val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+          if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+            throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
+              " elements due to exceeding the array size limit " +
+              ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
+          }
+          val finalData = new Array[AnyRef](numberOfElements.toInt)
+          var position = 0
+          for (ad <- arrayData) {
+            val arr = ad.toObjectArray(elementType)
+            Array.copy(arr, 0, finalData, position, arr.length)
+            position += arr.length
+          }
+          new GenericArrayData(finalData)
         }
-        new GenericArrayData(finalData)
       }
   }
 
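For readers who want the idea behind the diff in isolation, below is a minimal, standalone sketch of the pattern this commit applies: match on the data type once, lazily, to select an evaluation function, then reuse that function on every eval call instead of re-running the match per row. All names in the sketch (TinyType, IntsType, TextType, ReverseExpr, MatchOnceSketch) are hypothetical stand-ins invented for illustration; this is not Spark code, and it omits Spark-specific concerns such as marking the selected function @transient.

```scala
// Hypothetical stand-ins for illustration only; not Spark classes.
sealed trait TinyType
case object IntsType extends TinyType
case object TextType extends TinyType

final class ReverseExpr(tpe: TinyType) {

  // Before: the data type is pattern matched on every single call.
  def evalMatchingEveryCall(input: Any): Any = tpe match {
    case IntsType => input.asInstanceOf[Seq[Int]].reverse
    case TextType => input.asInstanceOf[String].reverse
  }

  // After: the match runs once, when doReverse is first forced, and picks a
  // per-row function; later calls only pay the cost of a function invocation.
  private lazy val doReverse: Any => Any = tpe match {
    case IntsType => in => in.asInstanceOf[Seq[Int]].reverse
    case TextType => in => in.asInstanceOf[String].reverse
  }

  def eval(input: Any): Any = doReverse(input)
}

object MatchOnceSketch {
  def main(args: Array[String]): Unit = {
    val expr = new ReverseExpr(TextType)
    // Both lines print "cba"; only the second avoids repeating the match per call.
    println(expr.evalMatchingEveryCall("abc"))
    println(expr.eval("abc"))
  }
}
```

In the actual diff the same idea is applied to Reverse, ElementAt and Concat, with the selected function held in a @transient private lazy val, so it is recomputed on first use after deserialization rather than shipped with the expression.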