Github user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/21061#discussion_r192254166
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -1882,3 +1882,311 @@ case class ArrayRepeat(left: Expression, right:
Expression)
}
}
+
/**
 * Helpers shared by the array set-like expressions (e.g. union of two arrays).
 */
object ArraySetLike {
  val kindUnion = 1

  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

  /**
   * Returns true when `numElements` primitives of `elementSize` bytes each can be stored in an
   * [[UnsafeArrayData]], whose data lives in a `long[]` buffer.
   *
   * All size arithmetic is done in Long. The previous inline check compared against
   * `Integer.MAX_VALUE * 8`, which is evaluated as Int and overflows to -8. That made the
   * check always false, so the unsafe path was never taken.
   *
   * NOTE(review): the bound of `Integer.MAX_VALUE / 8` words (~2 GB) is deliberately
   * conservative. It matches the size guard in `UnsafeArrayData.fromPrimitiveArray` in Spark 2.x,
   * so that call cannot reject the array; confirm against the Spark version in use. Exceeding
   * the bound only means falling back to GenericArrayData.
   */
  private def fitsInUnsafeArray(elementSize: Int, numElements: Int): Boolean = {
    val sizeInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
      ByteArrayMethods.roundNumberOfBytesToNearestWord(elementSize.toLong * numElements)
    val sizeInLongs = (sizeInBytes + 7) / 8
    sizeInLongs <= Integer.MAX_VALUE / 8
  }

  /**
   * Copies the elements of `hs`, in the order visited by `nextPos`, into an [[ArrayData]].
   * Uses [[UnsafeArrayData]] when the data fits, otherwise [[GenericArrayData]].
   */
  def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = {
    val array = new Array[Int](hs.size)
    var pos = hs.nextPos(0)
    var i = 0
    while (pos != OpenHashSet.INVALID_POS) {
      array(i) = hs.getValue(pos)
      pos = hs.nextPos(pos + 1)
      i += 1
    }

    if (fitsInUnsafeArray(4, array.length)) {
      UnsafeArrayData.fromPrimitiveArray(array)
    } else {
      new GenericArrayData(array)
    }
  }

  /**
   * Copies the elements of `hs`, in the order visited by `nextPos`, into an [[ArrayData]].
   * Uses [[UnsafeArrayData]] when the data fits, otherwise [[GenericArrayData]].
   */
  def toArrayDataLong(hs: OpenHashSet[Long]): ArrayData = {
    val array = new Array[Long](hs.size)
    var pos = hs.nextPos(0)
    var i = 0
    while (pos != OpenHashSet.INVALID_POS) {
      array(i) = hs.getValue(pos)
      pos = hs.nextPos(pos + 1)
      i += 1
    }

    if (fitsInUnsafeArray(8, array.length)) {
      UnsafeArrayData.fromPrimitiveArray(array)
    } else {
      new GenericArrayData(array)
    }
  }

  /**
   * Returns the union of `array1` and `array2` without duplicates, keeping first-occurrence order
   * (elements of `array1` first, then unseen elements of `array2`).
   *
   * @param array1   first input array
   * @param array2   second input array
   * @param et       element type of both arrays
   * @param ordering used to decide element equality via `equiv`; when null, duplicates are
   *                 removed with `distinct` on the object arrays
   * @return a [[GenericArrayData]] holding the distinct elements; at most one null is kept
   * @throws RuntimeException when `ordering` is non-null and the result would need more than
   *                          MAX_ARRAY_LENGTH elements
   */
  def arrayUnion(
      array1: ArrayData,
      array2: ArrayData,
      et: DataType,
      ordering: Ordering[Any]): ArrayData = {
    if (ordering == null) {
      new GenericArrayData(
        (array1.toObjectArray(et) ++ array2.toObjectArray(et)).distinct.asInstanceOf[Array[Any]])
    } else {
      val length = math.min(
        array1.numElements().toLong + array2.numElements().toLong,
        MAX_ARRAY_LENGTH.toLong)
      val array = new Array[Any](length.toInt)
      var pos = 0
      var hasNull = false
      Seq(array1, array2).foreach(_.foreach(et, (_, v) => {
        var found = false
        if (v == null) {
          // Keep only the first null.
          if (hasNull) {
            found = true
          } else {
            hasNull = true
          }
        } else {
          // Equality is decided by `ordering.equiv`, which need not agree with equals/hashCode,
          // so the already-collected prefix is scanned linearly instead of using a hash set.
          var j = 0
          while (!found && j < pos) {
            val va = array(j)
            if (va != null && ordering.equiv(va, v)) {
              found = true
            }
            j += 1
          }
        }
        if (!found) {
          // `array` has at most MAX_ARRAY_LENGTH slots. Check before writing so the caller gets
          // this error rather than an ArrayIndexOutOfBoundsException.
          if (pos >= MAX_ARRAY_LENGTH) {
            throw new RuntimeException(s"Unsuccessful try to union arrays with $pos" +
              s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
          }
          array(pos) = v
          pos += 1
        }
      }))
      new GenericArrayData(array.slice(0, pos))
    }
  }
}
+
+abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
+ def typeId: Int
+
+ override def dataType: DataType = left.dataType
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val typeCheckResult = super.checkInputDataTypes()
+ if (typeCheckResult.isSuccess) {
+
TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
+ s"function $prettyName")
+ } else {
+ typeCheckResult
+ }
+ }
+
+ private def cn = left.dataType.asInstanceOf[ArrayType].containsNull ||
+ right.dataType.asInstanceOf[ArrayType].containsNull
+
+ @transient private lazy val ordering: Ordering[Any] =
+ TypeUtils.getInterpretedOrdering(elementType)
+
+ @transient private lazy val elementTypeSupportEquals = elementType match
{
+ case BinaryType => false
+ case _: AtomicType => true
+ case _ => false
+ }
+
+ def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int]
+ def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long]
+ def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType):
OpenHashSet[Any]
+ def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String,
getter: String, i: String,
+ postFix: String, newOpenHashSet: String): String
+
+ override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val ary1 = input1.asInstanceOf[ArrayData]
+ val ary2 = input2.asInstanceOf[ArrayData]
+
+ if (!cn) {
+ elementType match {
+ case IntegerType =>
+ // avoid boxing of primitive int array elements
+ val hs2 = new OpenHashSet[Int]
--- End diff --
nit: why `hs2`?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]