Repository: spark
Updated Branches:
  refs/heads/branch-2.4 e42546259 -> 604828eda


[SPARK-25469][SQL] Eval methods of Concat, Reverse and ElementAt should use pattern matching only once

## What changes were proposed in this pull request?

The PR avoids re-running pattern matching on each call of the ```eval``` method within:
- ```Concat```
- ```Reverse```
- ```ElementAt```

Instead, the match on the data type is performed only once, when a lazily cached evaluation function is first resolved; a minimal sketch of the pattern is shown below.

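The following self-contained sketch illustrates the idea on a simplified expression class (the names ```ReverseLike``` and ```doEval``` are illustrative only, not the actual Spark code):

```scala
// Sketch: resolve the type-dependent evaluation function once, lazily,
// instead of pattern matching inside every eval() call.
case class ReverseLike(dataType: String) {

  // Matched once, on first use; @transient so the cached closure is not
  // serialized along with the expression (mirroring the real patch).
  @transient private lazy val doEval: Any => Any = dataType match {
    case "string" => input => input.asInstanceOf[String].reverse
    case "array"  => input => input.asInstanceOf[Seq[Any]].reverse
  }

  // The per-row hot path is now a plain function invocation.
  def eval(input: Any): Any = doEval(input)
}

// Usage:
//   ReverseLike("string").eval("abc")        // "cba"
//   ReverseLike("array").eval(Seq(1, 2, 3))  // Seq(3, 2, 1)
```
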
## How was this patch tested?

Ran the existing tests for the ```Concat```, ```Reverse```, and ```ElementAt``` expression classes.
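
As a quick illustrative sanity check (not part of the patch), one could also verify in ```spark-shell``` that the observable semantics are unchanged; the snippet below assumes the shell's predefined ```spark.implicits._``` import:

```scala
import org.apache.spark.sql.functions.{concat, element_at, reverse}

val df = Seq((Seq(1, 2, 3), Seq(4, 5), "abc")).toDF("a", "b", "s")
df.select(
  reverse($"s"),        // "cba"            (StringType branch of Reverse)
  reverse($"a"),        // [3, 2, 1]        (ArrayType branch of Reverse)
  element_at($"a", 1),  // 1                (ArrayType branch of ElementAt, 1-based)
  concat($"a", $"b")    // [1, 2, 3, 4, 5]  (ArrayType branch of Concat)
).show()
```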

Closes #22471 from mn-mikke/SPARK-25470.

Authored-by: Marek Novotny <mn.mi...@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>
(cherry picked from commit 2c9d8f56c71093faf152ca7136c5fcc4a7b2a95f)
Signed-off-by: Takeshi Yamamuro <yamam...@apache.org>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/604828ed
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/604828ed
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/604828ed

Branch: refs/heads/branch-2.4
Commit: 604828eda0930b933be39d5db7bdb1b29d499f32
Parents: e425462
Author: Marek Novotny <mn.mi...@gmail.com>
Authored: Fri Sep 21 18:16:54 2018 +0900
Committer: Takeshi Yamamuro <yamam...@apache.org>
Committed: Fri Sep 21 18:30:32 2018 +0900

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 81 ++++++++++++--------
 1 file changed, 48 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/604828ed/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index e23ebef..161adc9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -1268,11 +1268,15 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
 
   override def dataType: DataType = child.dataType
 
-  @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType
+  override def nullSafeEval(input: Any): Any = doReverse(input)
 
-  override def nullSafeEval(input: Any): Any = input match {
-    case a: ArrayData => new GenericArrayData(a.toObjectArray(elementType).reverse)
-    case s: UTF8String => s.reverse()
+  @transient private lazy val doReverse: Any => Any = dataType match {
+    case ArrayType(elementType, _) =>
+      input => {
+        val arrayData = input.asInstanceOf[ArrayData]
+        new GenericArrayData(arrayData.toObjectArray(elementType).reverse)
+      }
+    case StringType => _.asInstanceOf[UTF8String].reverse()
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -1294,6 +1298,7 @@ case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastI
     val i = ctx.freshName("i")
     val j = ctx.freshName("j")
 
+    val elementType = dataType.asInstanceOf[ArrayType].elementType
     val initialization = CodeGenerator.createArrayData(
       arrayData, elementType, numElements, s" $prettyName failed.")
     val assignment = CodeGenerator.createArrayAssignment(
@@ -2164,9 +2169,11 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
 
   override def nullable: Boolean = true
 
-  override def nullSafeEval(value: Any, ordinal: Any): Any = {
-    left.dataType match {
-      case _: ArrayType =>
+  override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)
+
+  @transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match {
+    case _: ArrayType =>
+      (value, ordinal) => {
         val array = value.asInstanceOf[ArrayData]
         val index = ordinal.asInstanceOf[Int]
         if (array.numElements() < math.abs(index)) {
@@ -2185,9 +2192,9 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
             array.get(idx, dataType)
           }
         }
-      case _: MapType =>
-        getValueEval(value, ordinal, mapKeyType, ordering)
-    }
+      }
+    case _: MapType =>
+      (value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -2278,33 +2285,41 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
 
   override def foldable: Boolean = children.forall(_.foldable)
 
-  override def eval(input: InternalRow): Any = dataType match {
+  override def eval(input: InternalRow): Any = doConcat(input)
+
+  @transient private lazy val doConcat: InternalRow => Any = dataType match {
     case BinaryType =>
-      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
-      ByteArray.concat(inputs: _*)
+      input => {
+        val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+        ByteArray.concat(inputs: _*)
+      }
     case StringType =>
-      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-      UTF8String.concat(inputs : _*)
+      input => {
+        val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+        UTF8String.concat(inputs: _*)
+      }
     case ArrayType(elementType, _) =>
-      val inputs = children.toStream.map(_.eval(input))
-      if (inputs.contains(null)) {
-        null
-      } else {
-        val arrayData = inputs.map(_.asInstanceOf[ArrayData])
-        val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
-        if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
-          throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
-            " elements due to exceeding the array size limit " +
-            ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
-        }
-        val finalData = new Array[AnyRef](numberOfElements.toInt)
-        var position = 0
-        for(ad <- arrayData) {
-          val arr = ad.toObjectArray(elementType)
-          Array.copy(arr, 0, finalData, position, arr.length)
-          position += arr.length
+      input => {
+        val inputs = children.toStream.map(_.eval(input))
+        if (inputs.contains(null)) {
+          null
+        } else {
+          val arrayData = inputs.map(_.asInstanceOf[ArrayData])
+          val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+          if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+            throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
+              " elements due to exceeding the array size limit " +
+              ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
+          }
+          val finalData = new Array[AnyRef](numberOfElements.toInt)
+          var position = 0
+          for (ad <- arrayData) {
+            val arr = ad.toObjectArray(elementType)
+            Array.copy(arr, 0, finalData, position, arr.length)
+            position += arr.length
+          }
+          new GenericArrayData(finalData)
         }
-        new GenericArrayData(finalData)
       }
   }
 

