Repository: spark
Updated Branches:
  refs/heads/master 0cea9e3cd -> ab1029fb8


[SPARK-23912][SQL][FOLLOWUP] Refactor ArrayDistinct

## What changes were proposed in this pull request?

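This PR simplifies code generation for `ArrayDistinct`. #21966 enabled code
generation only when the element type can be handled by a specialized hash set
(byte, short, int, long, float, double). This PR follows the same strategy:
all other element types fall back to the interpreted `nullSafeEval` path.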

Optimization of null handling will be implemented in #21912.

## How was this patch tested?

Existing unit tests.

Closes #22044 from kiszk/SPARK-23912-follow.

Authored-by: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Signed-off-by: Takuya UESHIN <ues...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab1029fb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab1029fb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab1029fb

Branch: refs/heads/master
Commit: ab1029fb8aae586e3af1238048e8b3dcfeb096f4
Parents: 0cea9e3
Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com>
Authored: Fri Aug 10 15:41:59 2018 +0900
Committer: Takuya UESHIN <ues...@databricks.com>
Committed: Fri Aug 10 15:41:59 2018 +0900

----------------------------------------------------------------------
 .../expressions/collectionOperations.scala      | 215 ++++++-------------
 1 file changed, 61 insertions(+), 154 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ab1029fb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index b37fdc6..5e3449d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -3410,6 +3410,28 @@ case class ArrayDistinct(child: Expression)
     case _ => false
   }
 
+  @transient protected lazy val canUseSpecializedHashSet = elementType match {
+    case ByteType | ShortType | IntegerType | LongType | FloatType | 
DoubleType => true
+    case _ => false
+  }
+
+  @transient protected lazy val (hsPostFix, hsTypeName) = {
+    val ptName = CodeGenerator.primitiveTypeName(elementType)
+    elementType match {
+      // we cast byte/short to int when writing to the hash set.
+      case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
+      case LongType => ("$mcJ$sp", ptName)
+      case FloatType => ("$mcF$sp", ptName)
+      case DoubleType => ("$mcD$sp", ptName)
+    }
+  }
+
+  // we cast byte/short to int when writing to the hash set.
+  @transient protected lazy val hsValueCast = elementType match {
+    case ByteType | ShortType => "(int) "
+    case _ => ""
+  }
+
   override def nullSafeEval(array: Any): Any = {
     val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
     if (elementTypeSupportEquals) {
@@ -3442,17 +3464,15 @@ case class ArrayDistinct(child: Expression)
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    nullSafeCodeGen(ctx, ev, (array) => {
-      val i = ctx.freshName("i")
-      val j = ctx.freshName("j")
-      val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
-      val getValue1 = CodeGenerator.getValue(array, elementType, i)
-      val getValue2 = CodeGenerator.getValue(array, elementType, j)
-      val foundNullElement = ctx.freshName("foundNullElement")
-      val openHashSet = classOf[OpenHashSet[_]].getName
-      val hs = ctx.freshName("hs")
-      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
-      if (elementTypeSupportEquals) {
+    if (canUseSpecializedHashSet) {
+      nullSafeCodeGen(ctx, ev, (array) => {
+        val i = ctx.freshName("i")
+        val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val hs = ctx.freshName("hs")
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val getValue = CodeGenerator.getValue(array, elementType, i)
         s"""
            |int $sizeOfDistinctArray = 0;
            |boolean $foundNullElement = false;
@@ -3461,53 +3481,26 @@ case class ArrayDistinct(child: Expression)
            |  if ($array.isNullAt($i)) {
            |    $foundNullElement = true;
            |  } else {
-           |    $hs.add($getValue1);
+           |    $hs.add$hsPostFix($hsValueCast$getValue);
            |  }
            |}
            |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
            |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
          """.stripMargin
-      } else {
-        s"""
-           |int $sizeOfDistinctArray = 0;
-           |boolean $foundNullElement = false;
-           |for (int $i = 0; $i < $array.numElements(); $i ++) {
-           |  if ($array.isNullAt($i)) {
-           |     if (!($foundNullElement)) {
-           |       $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
-           |       $foundNullElement = true;
-           |     }
-           |  } else {
-           |    int $j;
-           |    for ($j = 0; $j < $i; $j ++) {
-           |      if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, 
getValue1, getValue2)}) {
-           |        break;
-           |      }
-           |    }
-           |    if ($i == $j) {
-           |     $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
-           |    }
-           |  }
-           |}
-           |
-           |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
-         """.stripMargin
-      }
-    })
+      })
+    } else {
+      nullSafeCodeGen(ctx, ev, (array) => {
+        val expr = ctx.addReferenceObj("arrayDistinctExpr", this)
+        s"${ev.value} = (ArrayData)$expr.nullSafeEval($array);"
+      })
+    }
   }
 
   private def setNull(
-      isPrimitive: Boolean,
       foundNullElement: String,
       distinctArray: String,
       pos: String): String = {
-    val setNullValue =
-      if (!isPrimitive) {
-        s"$distinctArray[$pos] = null";
-      } else {
-        s"$distinctArray.setNullAt($pos)";
-      }
-
+    val setNullValue = s"$distinctArray.setNullAt($pos)"
     s"""
        |if (!($foundNullElement)) {
        |  $setNullValue;
@@ -3517,57 +3510,16 @@ case class ArrayDistinct(child: Expression)
     """.stripMargin
   }
 
-  private def setNotNullValue(isPrimitive: Boolean,
-      distinctArray: String,
-      pos: String,
-      getValue1: String,
-      primitiveValueTypeName: String): String = {
-    if (!isPrimitive) {
-      s"$distinctArray[$pos] = $getValue1";
-    } else {
-      s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)";
-    }
-  }
-
-  private def setValueForFastEval(
-      isPrimitive: Boolean,
+  private def setValue(
       hs: String,
       distinctArray: String,
       pos: String,
       getValue1: String,
       primitiveValueTypeName: String): String = {
-    val setValue = setNotNullValue(isPrimitive,
-      distinctArray, pos, getValue1, primitiveValueTypeName)
     s"""
-       |if (!($hs.contains($getValue1))) {
-       |  $hs.add($getValue1);
-       |  $setValue;
-       |  $pos = $pos + 1;
-       |}
-    """.stripMargin
-  }
-
-  private def setValueForBruteForceEval(
-      isPrimitive: Boolean,
-      i: String,
-      j: String,
-      inputArray: String,
-      distinctArray: String,
-      pos: String,
-      getValue1: String,
-      isEqual: String,
-      primitiveValueTypeName: String): String = {
-    val setValue = setNotNullValue(isPrimitive,
-      distinctArray, pos, getValue1, primitiveValueTypeName)
-    s"""
-       |int $j;
-       |for ($j = 0; $j < $i; $j ++) {
-       |  if (!$inputArray.isNullAt($j) && $isEqual) {
-       |    break;
-       |  }
-       |}
-       |if ($i == $j) {
-       |  $setValue;
+       |if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) {
+       |  $hs.add$hsPostFix($hsValueCast$getValue1);
+       |  $distinctArray.set$primitiveValueTypeName($pos, $getValue1);
        |  $pos = $pos + 1;
        |}
     """.stripMargin
@@ -3580,73 +3532,28 @@ case class ArrayDistinct(child: Expression)
       size: String): String = {
     val distinctArray = ctx.freshName("distinctArray")
     val i = ctx.freshName("i")
-    val j = ctx.freshName("j")
     val pos = ctx.freshName("pos")
     val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
-    val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
-    val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
     val foundNullElement = ctx.freshName("foundNullElement")
     val hs = ctx.freshName("hs")
     val openHashSet = classOf[OpenHashSet[_]].getName
-    if (!CodeGenerator.isPrimitiveType(elementType)) {
-      val arrayClass = classOf[GenericArrayData].getName
-      val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
-      val setNullForNonPrimitive =
-        setNull(false, foundNullElement, distinctArray, pos)
-      if (elementTypeSupportEquals) {
-        val setValueForFast = setValueForFastEval(false, hs, distinctArray, 
pos, getValue1, "")
-        s"""
-           |int $pos = 0;
-           |Object[] $distinctArray = new Object[$size];
-           |boolean $foundNullElement = false;
-           |$openHashSet $hs = new $openHashSet($classTag);
-           |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
-           |  if ($inputArray.isNullAt($i)) {
-           |    $setNullForNonPrimitive;
-           |  } else {
-           |    $setValueForFast;
-           |  }
-           |}
-           |${ev.value} = new $arrayClass($distinctArray);
-        """.stripMargin
-      } else {
-        val setValueForBruteForce = setValueForBruteForceEval(
-          false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "")
-        s"""
-           |int $pos = 0;
-           |Object[] $distinctArray = new Object[$size];
-           |boolean $foundNullElement = false;
-           |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
-           |  if ($inputArray.isNullAt($i)) {
-           |    $setNullForNonPrimitive;
-           |  } else {
-           |    $setValueForBruteForce;
-           |  }
-           |}
-           |${ev.value} = new $arrayClass($distinctArray);
-       """.stripMargin
-      }
-    } else {
-      val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
-      val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, 
pos)
-      val classTag = 
s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()"
-      val setValueForFast =
-        setValueForFastEval(true, hs, distinctArray, pos, getValue1, 
primitiveValueTypeName)
-      s"""
-         |${ctx.createUnsafeArray(distinctArray, size, elementType, s" 
$prettyName failed.")}
-         |int $pos = 0;
-         |boolean $foundNullElement = false;
-         |$openHashSet $hs = new $openHashSet($classTag);
-         |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
-         |  if ($inputArray.isNullAt($i)) {
-         |    $setNullForPrimitive;
-         |  } else {
-         |    $setValueForFast;
-         |  }
-         |}
-         |${ev.value} = $distinctArray;
-      """.stripMargin
-    }
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+    val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+
+    s"""
+       |${ctx.createUnsafeArray(distinctArray, size, elementType, s" 
$prettyName failed.")}
+       |int $pos = 0;
+       |boolean $foundNullElement = false;
+       |$openHashSet $hs = new $openHashSet($classTag);
+       |for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
+       |  if ($inputArray.isNullAt($i)) {
+       |    ${setNull(foundNullElement, distinctArray, pos)}
+       |  } else {
+       |    ${setValue(hs, distinctArray, pos, getValue1, 
primitiveValueTypeName)}
+       |  }
+       |}
+       |${ev.value} = $distinctArray;
+    """.stripMargin
   }
 
   override def prettyName: String = "array_distinct"


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to