Repository: spark Updated Branches: refs/heads/master 0cea9e3cd -> ab1029fb8
[SPARK-23912][SQL][FOLLOWUP] Refactor ArrayDistinct ## What changes were proposed in this pull request? This PR simplifies code generation for `ArrayDistinct`. #21966 enabled code generation only if the element type can be specialized by the hash set. This PR follows the same strategy. Optimization of null handling will be implemented in #21912. ## How was this patch tested? Existing unit tests. Closes #22044 from kiszk/SPARK-23912-follow. Authored-by: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Signed-off-by: Takuya UESHIN <ues...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ab1029fb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ab1029fb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ab1029fb Branch: refs/heads/master Commit: ab1029fb8aae586e3af1238048e8b3dcfeb096f4 Parents: 0cea9e3 Author: Kazuaki Ishizaki <ishiz...@jp.ibm.com> Authored: Fri Aug 10 15:41:59 2018 +0900 Committer: Takuya UESHIN <ues...@databricks.com> Committed: Fri Aug 10 15:41:59 2018 +0900 ---------------------------------------------------------------------- .../expressions/collectionOperations.scala | 215 ++++++------------- 1 file changed, 61 insertions(+), 154 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ab1029fb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index b37fdc6..5e3449d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ 
-3410,6 +3410,28 @@ case class ArrayDistinct(child: Expression) case _ => false } + @transient protected lazy val canUseSpecializedHashSet = elementType match { + case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true + case _ => false + } + + @transient protected lazy val (hsPostFix, hsTypeName) = { + val ptName = CodeGenerator.primitiveTypeName(elementType) + elementType match { + // we cast byte/short to int when writing to the hash set. + case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int") + case LongType => ("$mcJ$sp", ptName) + case FloatType => ("$mcF$sp", ptName) + case DoubleType => ("$mcD$sp", ptName) + } + } + + // we cast byte/short to int when writing to the hash set. + @transient protected lazy val hsValueCast = elementType match { + case ByteType | ShortType => "(int) " + case _ => "" + } + override def nullSafeEval(array: Any): Any = { val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType) if (elementTypeSupportEquals) { @@ -3442,17 +3464,15 @@ case class ArrayDistinct(child: Expression) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - nullSafeCodeGen(ctx, ev, (array) => { - val i = ctx.freshName("i") - val j = ctx.freshName("j") - val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") - val getValue1 = CodeGenerator.getValue(array, elementType, i) - val getValue2 = CodeGenerator.getValue(array, elementType, j) - val foundNullElement = ctx.freshName("foundNullElement") - val openHashSet = classOf[OpenHashSet[_]].getName - val hs = ctx.freshName("hs") - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()" - if (elementTypeSupportEquals) { + if (canUseSpecializedHashSet) { + nullSafeCodeGen(ctx, ev, (array) => { + val i = ctx.freshName("i") + val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray") + val foundNullElement = ctx.freshName("foundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val hs = ctx.freshName("hs") + val 
classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val getValue = CodeGenerator.getValue(array, elementType, i) s""" |int $sizeOfDistinctArray = 0; |boolean $foundNullElement = false; @@ -3461,53 +3481,26 @@ case class ArrayDistinct(child: Expression) | if ($array.isNullAt($i)) { | $foundNullElement = true; | } else { - | $hs.add($getValue1); + | $hs.add$hsPostFix($hsValueCast$getValue); | } |} |$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0); |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} """.stripMargin - } else { - s""" - |int $sizeOfDistinctArray = 0; - |boolean $foundNullElement = false; - |for (int $i = 0; $i < $array.numElements(); $i ++) { - | if ($array.isNullAt($i)) { - | if (!($foundNullElement)) { - | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; - | $foundNullElement = true; - | } - | } else { - | int $j; - | for ($j = 0; $j < $i; $j ++) { - | if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) { - | break; - | } - | } - | if ($i == $j) { - | $sizeOfDistinctArray = $sizeOfDistinctArray + 1; - | } - | } - |} - | - |${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)} - """.stripMargin - } - }) + }) + } else { + nullSafeCodeGen(ctx, ev, (array) => { + val expr = ctx.addReferenceObj("arrayDistinctExpr", this) + s"${ev.value} = (ArrayData)$expr.nullSafeEval($array);" + }) + } } private def setNull( - isPrimitive: Boolean, foundNullElement: String, distinctArray: String, pos: String): String = { - val setNullValue = - if (!isPrimitive) { - s"$distinctArray[$pos] = null"; - } else { - s"$distinctArray.setNullAt($pos)"; - } - + val setNullValue = s"$distinctArray.setNullAt($pos)" s""" |if (!($foundNullElement)) { | $setNullValue; @@ -3517,57 +3510,16 @@ case class ArrayDistinct(child: Expression) """.stripMargin } - private def setNotNullValue(isPrimitive: Boolean, - distinctArray: String, - pos: String, - getValue1: String, - primitiveValueTypeName: String): String = { - if 
(!isPrimitive) { - s"$distinctArray[$pos] = $getValue1"; - } else { - s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)"; - } - } - - private def setValueForFastEval( - isPrimitive: Boolean, + private def setValue( hs: String, distinctArray: String, pos: String, getValue1: String, primitiveValueTypeName: String): String = { - val setValue = setNotNullValue(isPrimitive, - distinctArray, pos, getValue1, primitiveValueTypeName) s""" - |if (!($hs.contains($getValue1))) { - | $hs.add($getValue1); - | $setValue; - | $pos = $pos + 1; - |} - """.stripMargin - } - - private def setValueForBruteForceEval( - isPrimitive: Boolean, - i: String, - j: String, - inputArray: String, - distinctArray: String, - pos: String, - getValue1: String, - isEqual: String, - primitiveValueTypeName: String): String = { - val setValue = setNotNullValue(isPrimitive, - distinctArray, pos, getValue1, primitiveValueTypeName) - s""" - |int $j; - |for ($j = 0; $j < $i; $j ++) { - | if (!$inputArray.isNullAt($j) && $isEqual) { - | break; - | } - |} - |if ($i == $j) { - | $setValue; + |if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) { + | $hs.add$hsPostFix($hsValueCast$getValue1); + | $distinctArray.set$primitiveValueTypeName($pos, $getValue1); | $pos = $pos + 1; |} """.stripMargin @@ -3580,73 +3532,28 @@ case class ArrayDistinct(child: Expression) size: String): String = { val distinctArray = ctx.freshName("distinctArray") val i = ctx.freshName("i") - val j = ctx.freshName("j") val pos = ctx.freshName("pos") val getValue1 = CodeGenerator.getValue(inputArray, elementType, i) - val getValue2 = CodeGenerator.getValue(inputArray, elementType, j) - val isEqual = ctx.genEqual(elementType, getValue1, getValue2) val foundNullElement = ctx.freshName("foundNullElement") val hs = ctx.freshName("hs") val openHashSet = classOf[OpenHashSet[_]].getName - if (!CodeGenerator.isPrimitiveType(elementType)) { - val arrayClass = classOf[GenericArrayData].getName - val classTag = 
s"scala.reflect.ClassTag$$.MODULE$$.Object()" - val setNullForNonPrimitive = - setNull(false, foundNullElement, distinctArray, pos) - if (elementTypeSupportEquals) { - val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "") - s""" - |int $pos = 0; - |Object[] $distinctArray = new Object[$size]; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForNonPrimitive; - | } else { - | $setValueForFast; - | } - |} - |${ev.value} = new $arrayClass($distinctArray); - """.stripMargin - } else { - val setValueForBruteForce = setValueForBruteForceEval( - false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "") - s""" - |int $pos = 0; - |Object[] $distinctArray = new Object[$size]; - |boolean $foundNullElement = false; - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForNonPrimitive; - | } else { - | $setValueForBruteForce; - | } - |} - |${ev.value} = new $arrayClass($distinctArray); - """.stripMargin - } - } else { - val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos) - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()" - val setValueForFast = - setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName) - s""" - |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} - |int $pos = 0; - |boolean $foundNullElement = false; - |$openHashSet $hs = new $openHashSet($classTag); - |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { - | if ($inputArray.isNullAt($i)) { - | $setNullForPrimitive; - | } else { - | $setValueForFast; - | } - |} - |${ev.value} = $distinctArray; - """.stripMargin - } + val primitiveValueTypeName = 
CodeGenerator.primitiveTypeName(elementType) + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + + s""" + |${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")} + |int $pos = 0; + |boolean $foundNullElement = false; + |$openHashSet $hs = new $openHashSet($classTag); + |for (int $i = 0; $i < $inputArray.numElements(); $i ++) { + | if ($inputArray.isNullAt($i)) { + | ${setNull(foundNullElement, distinctArray, pos)} + | } else { + | ${setValue(hs, distinctArray, pos, getValue1, primitiveValueTypeName)} + | } + |} + |${ev.value} = $distinctArray; + """.stripMargin } override def prettyName: String = "array_distinct" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org