Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21061#discussion_r192254166 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -1882,3 +1882,311 @@ case class ArrayRepeat(left: Expression, right: Expression) } } + +object ArraySetLike { + val kindUnion = 1 + + private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = { + val array = new Array[Int](hs.size) + var pos = hs.nextPos(0) + var i = 0 + while (pos != OpenHashSet.INVALID_POS) { + array(i) = hs.getValue(pos) + pos = hs.nextPos(pos + 1) + i += 1 + } + + val numBytes = 4L * array.length + val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + + org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used + if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { + UnsafeArrayData.fromPrimitiveArray(array) + } else { + new GenericArrayData(array) + } + } + + def toArrayDataLong(hs: OpenHashSet[Long]): ArrayData = { + val array = new Array[Long](hs.size) + var pos = hs.nextPos(0) + var i = 0 + while (pos != OpenHashSet.INVALID_POS) { + array(i) = hs.getValue(pos) + pos = hs.nextPos(pos + 1) + i += 1 + } + + val numBytes = 8L * array.length + val unsafeArraySizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(array.length) + + org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes) + // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max elements * 8 bytes can be used + if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) { + UnsafeArrayData.fromPrimitiveArray(array) + } else { + new GenericArrayData(array) + } + } + + def arrayUnion( + array1: ArrayData, + array2: ArrayData, + et: DataType, + ordering: Ordering[Any]): ArrayData = { + if (ordering == null) { + new GenericArrayData(array1.toObjectArray(et).union(array2.toObjectArray(et)) + .distinct.asInstanceOf[Array[Any]]) + } else { + val length = math.min(array1.numElements().toLong + array2.numElements().toLong, + ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + val array = new Array[Any](length.toInt) + var pos = 0 + var hasNull = false + Seq(array1, array2).foreach(_.foreach(et, (_, v) => { + var found = false + if (v == null) { + if (hasNull) { + found = true + } else { + hasNull = true + } + } else { + var j = 0 + while (!found && j < pos) { + val va = array(j) + if (va != null && ordering.equiv(va, v)) { + found = true + } + j = j + 1 + } + } + if (!found) { + if (pos > MAX_ARRAY_LENGTH) { + throw new RuntimeException(s"Unsuccessful try to union arrays with $pos" + + s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.") + } + array(pos) = v + pos = pos + 1 + } + })) + new GenericArrayData(array.slice(0, pos)) + } + } +} + +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast { + def typeId: Int + + override def dataType: DataType = left.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + val typeCheckResult = super.checkInputDataTypes() + if (typeCheckResult.isSuccess) { + TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType, + s"function $prettyName") + } else { + typeCheckResult + } + } + + private def cn = left.dataType.asInstanceOf[ArrayType].containsNull || + right.dataType.asInstanceOf[ArrayType].containsNull + + @transient private lazy val ordering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(elementType) + + @transient private lazy val elementTypeSupportEquals = elementType match { + case BinaryType => false + case _: AtomicType => true + case _ => false + } + + def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int] + def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long] + def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType): OpenHashSet[Any] + def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String, getter: String, i: String, + postFix: String, newOpenHashSet: String): String + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val ary1 = input1.asInstanceOf[ArrayData] + val ary2 = input2.asInstanceOf[ArrayData] + + if (!cn) { + elementType match { + case IntegerType => + // avoid boxing of primitive int array elements + val hs2 = new OpenHashSet[Int] --- End diff -- nit: why `hs2`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org