Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21937#discussion_r207767113 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -3698,230 +3767,162 @@ object ArraySetLike { """, since = "2.4.0") case class ArrayUnion(left: Expression, right: Expression) extends ArraySetLike - with ComplexTypeMergingExpression { - var hsInt: OpenHashSet[Int] = _ - var hsLong: OpenHashSet[Long] = _ - - def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getInt(idx) - if (!hsInt.contains(elem)) { - if (resultArray != null) { - resultArray.setInt(pos, elem) - } - hsInt.add(elem) - true - } else { - false - } - } - - def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: Int): Boolean = { - val elem = array.getLong(idx) - if (!hsLong.contains(elem)) { - if (resultArray != null) { - resultArray.setLong(pos, elem) - } - hsLong.add(elem) - true - } else { - false - } - } + with ComplexTypeMergingExpression { - def evalIntLongPrimitiveType( - array1: ArrayData, - array2: ArrayData, - resultArray: ArrayData, - isLongType: Boolean): Int = { - // store elements into resultArray - var nullElementSize = 0 - var pos = 0 - Seq(array1, array2).foreach { array => - var i = 0 - while (i < array.numElements()) { - val size = if (!isLongType) hsInt.size else hsLong.size - if (size + nullElementSize > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(size) - } - if (array.isNullAt(i)) { - if (nullElementSize == 0) { - if (resultArray != null) { - resultArray.setNullAt(pos) + @transient lazy val evalUnion: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + val hs = new OpenHashSet[Any] + var foundNullElement = false + Seq(array1, array2).foreach { array => + var i = 0 + 
while (i < array.numElements()) { + if (array.isNullAt(i)) { + if (!foundNullElement) { + arrayBuffer += null + foundNullElement = true + } + } else { + val elem = array.get(i, elementType) + if (!hs.contains(elem)) { + if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) + } + arrayBuffer += elem + hs.add(elem) + } } - pos += 1 - nullElementSize = 1 + i += 1 } - } else { - val assigned = if (!isLongType) { - assignInt(array, i, resultArray, pos) + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadyIncludeNull = false + Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => { + var found = false + if (elem == null) { + if (alreadyIncludeNull) { + found = true + } else { + alreadyIncludeNull = true + } } else { - assignLong(array, i, resultArray, pos) + // check elem is already stored in arrayBuffer or not? 
+ var j = 0 + while (!found && j < arrayBuffer.size) { + val va = arrayBuffer(j) + if (va != null && ordering.equiv(va, elem)) { + found = true + } + j = j + 1 + } } - if (assigned) { - pos += 1 + if (!found) { + if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { + ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.length) + } + arrayBuffer += elem } - } - i += 1 - } + })) + new GenericArrayData(arrayBuffer) } - pos } override def nullSafeEval(input1: Any, input2: Any): Any = { val array1 = input1.asInstanceOf[ArrayData] val array2 = input2.asInstanceOf[ArrayData] - if (elementTypeSupportEquals) { - elementType match { - case IntegerType => - // avoid boxing of primitive int array elements - // calculate result array size - hsInt = new OpenHashSet[Int] - val elements = evalIntLongPrimitiveType(array1, array2, null, false) - hsInt = new OpenHashSet[Int] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - IntegerType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize) - } - evalIntLongPrimitiveType(array1, array2, resultArray, false) - resultArray - case LongType => - // avoid boxing of primitive long array elements - // calculate result array size - hsLong = new OpenHashSet[Long] - val elements = evalIntLongPrimitiveType(array1, array2, null, true) - hsLong = new OpenHashSet[Long] - val resultArray = if (UnsafeArrayData.shouldUseGenericArrayData( - LongType.defaultSize, elements)) { - new GenericArrayData(new Array[Any](elements)) - } else { - UnsafeArrayData.forPrimitiveArray( - Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize) - } - evalIntLongPrimitiveType(array1, array2, resultArray, true) - resultArray - case _ => - val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] - val hs = new OpenHashSet[Any] - var foundNullElement = false - Seq(array1, array2).foreach { 
array => - var i = 0 - while (i < array.numElements()) { - if (array.isNullAt(i)) { - if (!foundNullElement) { - arrayBuffer += null - foundNullElement = true - } - } else { - val elem = array.get(i, elementType) - if (!hs.contains(elem)) { - if (arrayBuffer.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { - ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size) - } - arrayBuffer += elem - hs.add(elem) - } - } - i += 1 - } - } - new GenericArrayData(arrayBuffer) - } - } else { - ArrayUnion.unionOrdering(array1, array2, elementType, ordering) - } + evalUnion(array1, array2) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName val i = ctx.freshName("i") - val pos = ctx.freshName("pos") val value = ctx.freshName("value") val size = ctx.freshName("size") - val (postFix, openHashElementType, getter, setter, javaTypeName, castOp, arrayBuilder) = - if (elementTypeSupportEquals) { - elementType match { - case ByteType | ShortType | IntegerType | LongType => - val ptName = CodeGenerator.primitiveTypeName(elementType) - val unsafeArray = ctx.freshName("unsafeArray") - (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp", - if (elementType == LongType) "Long" else "Int", - s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), - if (elementType == LongType) "(long)" else "(int)", - s""" - |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} - |${ev.value} = $unsafeArray; - """.stripMargin) - case _ => - val genericArrayData = classOf[GenericArrayData].getName - val et = ctx.addReferenceObj("elementType", elementType) - ("", "Object", - s"get($i, $et)", s"update($pos, $value)", "Object", "", - s"${ev.value} = new $genericArrayData(new Object[$size]);") - } - } else { - ("", "", "", "", "", "", "") - } + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) - 
nullSafeCodeGen(ctx, ev, (array1, array2) => { - if (openHashElementType != "") { - // Here, we ensure elementTypeSupportEquals is true + nullSafeCodeGen(ctx, ev, (array1, array2) => { val foundNullElement = ctx.freshName("foundNullElement") - val openHashSet = classOf[OpenHashSet[_]].getName - val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" - val hs = ctx.freshName("hs") - val arrayData = classOf[ArrayData].getName - val arrays = ctx.freshName("arrays") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") val array = ctx.freshName("array") + val arrays = ctx.freshName("arrays") val arrayDataIdx = ctx.freshName("arrayDataIdx") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" + + def withArrayNullAssignment(body: String) = + if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array.isNullAt($i)) { + | if (!$foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = true; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + body + } + + val processArray = withArrayNullAssignment( + s""" + |$jt $value = ${genGetValue(array, i)}; + |if (!$hashSet.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSet.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. 
+ val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + s""" - |$openHashSet $hs = new $openHashSet$postFix($classTag); - |boolean $foundNullElement = false; + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |int $size = 0; + |$arrayBuilderClass $builder = + | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); --- End diff -- nit: `new $arrayBuilderClass()` should work here, instead of casting the result of `$arrayBuilder.make($arrayBuilderClassTag)`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org