Repository: spark
Updated Branches:
  refs/heads/master 0ab07b357 -> 7b6d36bc9


[SPARK-24871][SQL] Refactor Concat and MapConcat to avoid creating concatenator 
object for each row.

## What changes were proposed in this pull request?

Refactor `Concat` and `MapConcat` to:

- avoid creating a concatenator object for each row.
- make `Concat` handle `containsNull` properly.
- make `Concat` short-circuit if a `null` child is found.

## How was this patch tested?

Added some new tests; the change is also covered by existing tests.

Author: Takuya UESHIN <[email protected]>

Closes #21824 from ueshin/issues/SPARK-24871/refactor_concat_mapconcat.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7b6d36bc
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7b6d36bc
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7b6d36bc

Branch: refs/heads/master
Commit: 7b6d36bc9ef7a0af5e7461bba31c0e2518e3ce8d
Parents: 0ab07b3
Author: Takuya UESHIN <[email protected]>
Authored: Fri Jul 20 20:08:42 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Fri Jul 20 20:08:42 2018 +0800

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 299 +++++++++++--------
 .../CollectionExpressionsSuite.scala            |  15 +
 2 files changed, 192 insertions(+), 122 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7b6d36bc/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index 9263541..f438748 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -571,16 +571,25 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
         |$mapDataClass ${ev.value} = null;
       """.stripMargin
 
-    val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
-      s"""
-         |if (!$hasNullName) {
-         |  ${m.code}
-         |  $argsName[$i] = ${m.value};
-         |  if (${m.isNull}) {
-         |    $hasNullName = true;
-         |  }
-         |}
-       """.stripMargin
+    val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map {
+      case ((m, true), i) =>
+        s"""
+           |if (!$hasNullName) {
+           |  ${m.code}
+           |  if (!${m.isNull}) {
+           |    $argsName[$i] = ${m.value};
+           |  } else {
+           |    $hasNullName = true;
+           |  }
+           |}
+         """.stripMargin
+      case ((m, false), i) =>
+        s"""
+           |if (!$hasNullName) {
+           |  ${m.code}
+           |  $argsName[$i] = ${m.value};
+           |}
+         """.stripMargin
     }
 
     val codes = ctx.splitExpressionsWithCurrentInputs(
@@ -601,17 +610,21 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
     val finKeysName = ctx.freshName("finalKeys")
     val finValsName = ctx.freshName("finalValues")
 
-    val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
+    val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) {
       genCodeForPrimitiveArrays(ctx, keyType, false)
     } else {
       genCodeForNonPrimitiveArrays(ctx, keyType)
     }
 
-    val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
-      genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
-    } else {
-      genCodeForNonPrimitiveArrays(ctx, valueType)
-    }
+    val valueConcat =
+      if (valueType.sameType(keyType) &&
+          !(CodeGenerator.isPrimitiveType(valueType) && 
dataType.valueContainsNull)) {
+        keyConcat
+      } else if (CodeGenerator.isPrimitiveType(valueType)) {
+        genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
+      } else {
+        genCodeForNonPrimitiveArrays(ctx, valueType)
+      }
 
     val keyArgsName = ctx.freshName("keyArgs")
     val valArgsName = ctx.freshName("valArgs")
@@ -633,9 +646,9 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
         |       $numElementsName + " elements due to exceeding the map size 
limit " +
         |       "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
         |  }
-        |  $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
+        |  $arrayDataClass $finKeysName = $keyConcat($keyArgsName,
         |    (int) $numElementsName);
-        |  $arrayDataClass $finValsName = 
$valueConcatenator.concat($valArgsName,
+        |  $arrayDataClass $finValsName = $valueConcat($valArgsName,
         |    (int) $numElementsName);
         |  ${ev.value} = new $arrayBasedMapDataClass($finKeysName, 
$finValsName);
         |}
@@ -677,20 +690,23 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
       setterCode1
     }
 
-    s"""
-       |new Object() {
-       |  public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, 
int $numElemName) {
-       |    ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" 
$prettyName failed.")}
-       |    int $counter = 0;
-       |    for (int y = 0; y < ${children.length}; y++) {
-       |      for (int z = 0; z < $argsName[y].numElements(); z++) {
-       |        $setterCode
-       |        $counter++;
-       |      }
-       |    }
-       |    return $arrayData;
-       |  }
-       |}""".stripMargin.stripPrefix("\n")
+    val concat = ctx.freshName("concat")
+    val concatDef =
+      s"""
+         |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
+         |  ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" 
$prettyName failed.")}
+         |  int $counter = 0;
+         |  for (int y = 0; y < ${children.length}; y++) {
+         |    for (int z = 0; z < $argsName[y].numElements(); z++) {
+         |      $setterCode
+         |      $counter++;
+         |    }
+         |  }
+         |  return $arrayData;
+         |}
+       """.stripMargin
+
+    ctx.addNewFunction(concat, concatDef)
   }
 
   private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: 
DataType): String = {
@@ -700,20 +716,23 @@ case class MapConcat(children: Seq[Expression]) extends 
ComplexTypeMergingExpres
     val argsName = ctx.freshName("args")
     val numElemName = ctx.freshName("numElements")
 
-    s"""
-       |new Object() {
-       |  public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, 
int $numElemName) {;
-       |    Object[] $arrayData = new Object[$numElemName];
-       |    int $counter = 0;
-       |    for (int y = 0; y < ${children.length}; y++) {
-       |      for (int z = 0; z < $argsName[y].numElements(); z++) {
-       |        $arrayData[$counter] = 
${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
-       |        $counter++;
-       |      }
-       |    }
-       |    return new $genericArrayClass($arrayData);
-       |  }
-       |}""".stripMargin.stripPrefix("\n")
+    val concat = ctx.freshName("concat")
+    val concatDef =
+      s"""
+         |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
+         |  Object[] $arrayData = new Object[$numElemName];
+         |  int $counter = 0;
+         |  for (int y = 0; y < ${children.length}; y++) {
+         |    for (int z = 0; z < $argsName[y].numElements(); z++) {
+         |      $arrayData[$counter] = 
${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
+         |      $counter++;
+         |    }
+         |  }
+         |  return new $genericArrayClass($arrayData);
+         |}
+       """.stripMargin
+
+    ctx.addNewFunction(concat, concatDef)
   }
 
   override def prettyName: String = "map_concat"
@@ -2270,39 +2289,67 @@ case class Concat(children: Seq[Expression]) extends 
ComplexTypeMergingExpressio
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
     val evals = children.map(_.genCode(ctx))
     val args = ctx.freshName("args")
+    val hasNull = ctx.freshName("hasNull")
 
-    val inputs = evals.zipWithIndex.map { case (eval, index) =>
-      s"""
-        ${eval.code}
-        if (!${eval.isNull}) {
-          $args[$index] = ${eval.value};
-        }
-      """
+    val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map {
+      case ((eval, true), index) =>
+        s"""
+           |if (!$hasNull) {
+           |  ${eval.code}
+           |  if (!${eval.isNull}) {
+           |    $args[$index] = ${eval.value};
+           |  } else {
+           |    $hasNull = true;
+           |  }
+           |}
+         """.stripMargin
+      case ((eval, false), index) =>
+        s"""
+           |if (!$hasNull) {
+           |  ${eval.code}
+           |  $args[$index] = ${eval.value};
+           |}
+         """.stripMargin
     }
 
-    val (concatenator, initCode) = dataType match {
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = inputs,
+      funcName = "valueConcat",
+      extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil,
+      returnType = "boolean",
+      makeSplitFunction = body =>
+        s"""
+           |$body
+           |return $hasNull;
+         """.stripMargin,
+      foldFunctions = _.map(funcCall => s"$hasNull = 
$funcCall;").mkString("\n")
+    )
+
+    val (concat, initCode) = dataType match {
       case BinaryType =>
-        (classOf[ByteArray].getName, s"byte[][] $args = new 
byte[${evals.length}][];")
+        (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new 
byte[${evals.length}][];")
       case StringType =>
-        ("UTF8String", s"UTF8String[] $args = new 
UTF8String[${evals.length}];")
-      case ArrayType(elementType, _) =>
-        val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) 
{
-          genCodeForPrimitiveArrays(ctx, elementType)
+        ("UTF8String.concat", s"UTF8String[] $args = new 
UTF8String[${evals.length}];")
+      case ArrayType(elementType, containsNull) =>
+        val concat = if (CodeGenerator.isPrimitiveType(elementType)) {
+          genCodeForPrimitiveArrays(ctx, elementType, containsNull)
         } else {
           genCodeForNonPrimitiveArrays(ctx, elementType)
         }
-        (arrayConcatClass, s"ArrayData[] $args = new 
ArrayData[${evals.length}];")
+        (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];")
     }
-    val codes = ctx.splitExpressionsWithCurrentInputs(
-      expressions = inputs,
-      funcName = "valueConcat",
-      extraArguments = (s"$javaType[]", args) :: Nil)
-    ev.copy(code"""
-      $initCode
-      $codes
-      $javaType ${ev.value} = $concatenator.concat($args);
-      boolean ${ev.isNull} = ${ev.value} == null;
-    """)
+
+    ev.copy(code =
+      code"""
+         |boolean $hasNull = false;
+         |$initCode
+         |$codes
+         |$javaType ${ev.value} = null;
+         |if (!$hasNull) {
+         |  ${ev.value} = $concat($args);
+         |}
+         |boolean ${ev.isNull} = ${ev.value} == null;
+       """.stripMargin)
   }
 
   private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, 
String) = {
@@ -2322,19 +2369,10 @@ case class Concat(children: Seq[Expression]) extends 
ComplexTypeMergingExpressio
     (code, numElements)
   }
 
-  private def nullArgumentProtection() : String = {
-    if (nullable) {
-      s"""
-         |for (int z = 0; z < ${children.length}; z++) {
-         |  if (args[z] == null) return null;
-         |}
-       """.stripMargin
-    } else {
-      ""
-    }
-  }
-
-  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: 
DataType): String = {
+  private def genCodeForPrimitiveArrays(
+      ctx: CodegenContext,
+      elementType: DataType,
+      checkForNull: Boolean): String = {
     val counter = ctx.freshName("counter")
     val arrayData = ctx.freshName("arrayData")
 
@@ -2342,29 +2380,44 @@ case class Concat(children: Seq[Expression]) extends 
ComplexTypeMergingExpressio
 
     val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
 
-    s"""
-       |new Object() {
-       |  public ArrayData concat($javaType[] args) {
-       |    ${nullArgumentProtection()}
-       |    $numElemCode
-       |    ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" 
$prettyName failed.")}
-       |    int $counter = 0;
-       |    for (int y = 0; y < ${children.length}; y++) {
-       |      for (int z = 0; z < args[y].numElements(); z++) {
-       |        if (args[y].isNullAt(z)) {
-       |          $arrayData.setNullAt($counter);
-       |        } else {
-       |          $arrayData.set$primitiveValueTypeName(
-       |            $counter,
-       |            ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
-       |          );
-       |        }
-       |        $counter++;
-       |      }
-       |    }
-       |    return $arrayData;
-       |  }
-       |}""".stripMargin.stripPrefix("\n")
+    val setterCode =
+      s"""
+         |$arrayData.set$primitiveValueTypeName(
+         |  $counter,
+         |  ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
+         |);
+       """.stripMargin
+
+    val nullSafeSetterCode = if (checkForNull) {
+      s"""
+         |if (args[y].isNullAt(z)) {
+         |  $arrayData.setNullAt($counter);
+         |} else {
+         |  $setterCode
+         |}
+       """.stripMargin
+    } else {
+      setterCode
+    }
+
+    val concat = ctx.freshName("concat")
+    val concatDef =
+      s"""
+         |private ArrayData $concat(ArrayData[] args) {
+         |  $numElemCode
+         |  ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" 
$prettyName failed.")}
+         |  int $counter = 0;
+         |  for (int y = 0; y < ${children.length}; y++) {
+         |    for (int z = 0; z < args[y].numElements(); z++) {
+         |      $nullSafeSetterCode
+         |      $counter++;
+         |    }
+         |  }
+         |  return $arrayData;
+         |}
+       """.stripMargin
+
+    ctx.addNewFunction(concat, concatDef)
   }
 
   private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: 
DataType): String = {
@@ -2374,22 +2427,24 @@ case class Concat(children: Seq[Expression]) extends 
ComplexTypeMergingExpressio
 
     val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
 
-    s"""
-       |new Object() {
-       |  public ArrayData concat($javaType[] args) {
-       |    ${nullArgumentProtection()}
-       |    $numElemCode
-       |    Object[] $arrayData = new Object[(int)$numElemName];
-       |    int $counter = 0;
-       |    for (int y = 0; y < ${children.length}; y++) {
-       |      for (int z = 0; z < args[y].numElements(); z++) {
-       |        $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", 
elementType, "z")};
-       |        $counter++;
-       |      }
-       |    }
-       |    return new $genericArrayClass($arrayData);
-       |  }
-       |}""".stripMargin.stripPrefix("\n")
+    val concat = ctx.freshName("concat")
+    val concatDef =
+      s"""
+         |private ArrayData $concat(ArrayData[] args) {
+         |  $numElemCode
+         |  Object[] $arrayData = new Object[(int)$numElemName];
+         |  int $counter = 0;
+         |  for (int y = 0; y < ${children.length}; y++) {
+         |    for (int z = 0; z < args[y].numElements(); z++) {
+         |      $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", 
elementType, "z")};
+         |      $counter++;
+         |    }
+         |  }
+         |  return new $genericArrayClass($arrayData);
+         |}
+       """.stripMargin
+
+    ctx.addNewFunction(concat, concatDef)
   }
 
   override def toString: String = s"concat(${children.mkString(", ")})"

http://git-wip-us.apache.org/repos/asf/spark/blob/7b6d36bc/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index f1e3bd0..c7f0da7 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -125,6 +125,12 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
       valueContainsNull = false))
     val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, 
StringType,
       valueContainsNull = false))
+    val m13 = Literal.create(Map(1 -> 2, 3 -> 4),
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val m14 = Literal.create(Map(5 -> 6),
+      MapType(IntegerType, IntegerType, valueContainsNull = false))
+    val m15 = Literal.create(Map(7 -> null),
+      MapType(IntegerType, IntegerType, valueContainsNull = true))
     val mNull = Literal.create(null, MapType(StringType, StringType))
 
     // overlapping maps
@@ -188,6 +194,12 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
       )
     )
 
+    // both keys and value are primitive and valueContainsNull = false
+    checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6))
+
+    // both keys and value are primitive and valueContainsNull = true
+    checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null))
+
     // null map
     checkEvaluation(MapConcat(Seq(m0, mNull)), null)
     checkEvaluation(MapConcat(Seq(mNull, m0)), null)
@@ -1121,6 +1133,9 @@ class CollectionExpressionsSuite extends SparkFunSuite 
with ExpressionEvalHelper
       ArrayType(ArrayType(StringType, containsNull = false), containsNull = 
false))
     assert(Concat(Seq(aa0, aa2)).dataType ===
       ArrayType(ArrayType(StringType, containsNull = true), containsNull = 
true))
+
+    // force split expressions for input in generated code
+    checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 
3)).flatten)
   }
 
   test("Flatten") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to