Github user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/21061#discussion_r194530787
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -2189,3 +2189,293 @@ case class ArrayRemove(left: Expression, right:
Expression)
override def prettyName: String = "array_remove"
}
+
/** Shared helpers for the array set-like expressions (`array_union`, ...). */
object ArraySetLike {
  /**
   * Decides whether a primitive result array should be materialized as a
   * [[GenericArrayData]] rather than via `UnsafeArrayData.fromPrimitiveArray`.
   *
   * Uses the same size calculation as `UnsafeArrayData.fromPrimitiveArray()`.
   *
   * @param elementSize size in bytes of one element (e.g. `IntegerType.defaultSize`)
   * @param length      number of elements in the result
   * @return true when header plus values, rounded up to 8-byte words, exceed
   *         `Integer.MAX_VALUE / 8` words
   */
  def useGenericArrayData(elementSize: Int, length: Int): Boolean = {
    // Use the same calculation in UnsafeArrayData.fromPrimitiveArray()
    val headerInBytes = UnsafeArrayData.calculateHeaderPortionInBytes(length)
    // Widen to Long before multiplying so the product cannot overflow Int.
    val valueRegionInBytes = elementSize.toLong * length
    val totalSizeInLongs = (headerInBytes + valueRegionInBytes + 7) / 8
    totalSizeInLongs > Integer.MAX_VALUE / 8
  }

  /**
   * Fails a union whose result would exceed
   * `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH` elements.
   *
   * @param length number of elements collected when the limit was hit
   * @throws RuntimeException always
   */
  def throwUnionLengthOverflowException(length: Int): Unit = {
    // Note the trailing space after $length: the message previously read "with 42elements".
    throw new RuntimeException(s"Unsuccessful try to union arrays with $length " +
      "elements due to exceeding the array size limit " +
      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
  }

  /**
   * Interpreted union of two arrays that may contain null elements, keeping the first
   * occurrence of every distinct element in input order (array1 first, then array2).
   *
   * @param array1      left input
   * @param array2      right input
   * @param elementType element type used to read elements out of the inputs
   * @param ordering    when null, duplicates are detected with equals/hashCode through a
   *                    hash set (ArrayUnion passes null when the element type supports
   *                    equals); otherwise with `ordering.equiv` by a linear scan over the
   *                    elements collected so far, with nulls tracked separately
   * @return the deduplicated union as a [[GenericArrayData]]
   * @throws RuntimeException if the result grows past `MAX_ROUNDED_ARRAY_LENGTH`
   */
  def evalUnionContainsNull(
      array1: ArrayData,
      array2: ArrayData,
      elementType: DataType,
      ordering: Ordering[Any]): ArrayData = {
    val arrayBuffer = new mutable.ArrayBuffer[Any]

    // Appends a newly seen element, failing first if the size limit has been passed.
    def appendChecked(elem: Any): Unit = {
      if (arrayBuffer.length > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
        throwUnionLengthOverflowException(arrayBuffer.length)
      }
      arrayBuffer += elem
    }

    if (ordering == null) {
      val seen = new mutable.HashSet[Any]
      Seq(array1, array2).foreach { array =>
        var i = 0
        while (i < array.numElements()) {
          val elem = array.get(i, elementType)
          // HashSet.add returns false when an equal element was already present.
          if (seen.add(elem)) {
            appendChecked(elem)
          }
          i += 1
        }
      }
    } else {
      var alreadyIncludeNull = false
      Seq(array1, array2).foreach(_.foreach(elementType, (_, elem) => {
        val found = if (elem == null) {
          // Only the first null is kept.
          val duplicateNull = alreadyIncludeNull
          alreadyIncludeNull = true
          duplicateNull
        } else {
          // Nulls in the buffer are skipped: ordering.equiv is only applied to non-null values.
          arrayBuffer.exists(va => va != null && ordering.equiv(va, elem))
        }
        if (!found) {
          appendChecked(elem)
        }
      }))
    }
    new GenericArrayData(arrayBuffer)
  }
}
+
+
/**
 * Base class for binary set-like operations over two arrays (e.g. `array_union`).
 *
 * Requires an orderable element type. Shared state for subclasses:
 *  - whether either input may contain nulls ([[cn]]);
 *  - an interpreted ordering over the element type;
 *  - whether elements can be deduplicated with plain equals/hashCode.
 */
abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
  /**
   * Result array type: the left input's array type with `containsNull` set to [[cn]].
   * Previously this was `left.dataType` verbatim, which declared `containsNull = false`
   * even when nulls could come from a nullable right input.
   */
  override def dataType: DataType =
    left.dataType.asInstanceOf[ArrayType].copy(containsNull = cn)

  /** Runs the parent input-type check, then requires the element type to be orderable. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val typeCheckResult = super.checkInputDataTypes()
    if (typeCheckResult.isSuccess) {
      TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
        s"function $prettyName")
    } else {
      typeCheckResult
    }
  }

  /** True when either input's array type declares that it may contain null elements. */
  protected def cn: Boolean = left.dataType.asInstanceOf[ArrayType].containsNull ||
    right.dataType.asInstanceOf[ArrayType].containsNull

  /** Interpreted ordering over the element type, for types that cannot rely on equals. */
  @transient protected lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  /**
   * Whether elements can be deduplicated with equals/hashCode: true for atomic types
   * except BinaryType (presumably because binary values are byte arrays compared by
   * reference — confirm), false for everything else.
   */
  @transient protected lazy val elementTypeSupportEquals: Boolean = elementType match {
    case BinaryType => false
    case _: AtomicType => true
    case _ => false
  }
}
+
+/**
+ * Returns an array of the elements in the union of x and y, without
duplicates
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array1, array2) - Returns an array of the elements in the union
of array1 and array2,
+ without duplicates.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+ array(1, 2, 3, 5)
+ """,
+ since = "2.4.0")
+case class ArrayUnion(left: Expression, right: Expression) extends
ArraySetLike {
+
+ override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val array1 = input1.asInstanceOf[ArrayData]
+ val array2 = input2.asInstanceOf[ArrayData]
+
+ if (!cn) {
+ elementType match {
+ case IntegerType =>
+ // avoid boxing of primitive int array elements
+ // calculate result array size
+ val hsSize = new OpenHashSet[Int]
+ Seq(array1, array2).foreach(array => {
+ var i = 0
+ while (i < array.numElements()) {
+ if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH)
{
+ ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
+ }
+ hsSize.add(array.getInt(i))
+ i += 1
+ }
+ })
+ // store elements into array
+ val resultArray = new Array[Int](hsSize.size)
+ val hs = new OpenHashSet[Int]
+ var pos = 0
+ Seq(array1, array2).foreach(array => {
+ var i = 0
+ while (i < array.numElements () ) {
+ val elem = array.getInt (i)
+ if (!hs.contains (elem) ) {
+ resultArray (pos) = elem
+ hs.add (elem)
+ pos += 1
+ }
+ i += 1
+ }
+ })
+ if (ArraySetLike.useGenericArrayData(IntegerType.defaultSize,
resultArray.length)) {
+ new GenericArrayData(resultArray)
+ } else {
+ UnsafeArrayData.fromPrimitiveArray(resultArray)
+ }
+ case LongType =>
+ // avoid boxing of primitive long array elements
+ // calculate result array size
+ val hsSize = new OpenHashSet[Long]
+ Seq(array1, array2).foreach(array => {
+ var i = 0
+ while (i < array.numElements()) {
+ if (hsSize.size > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH)
{
+ ArraySetLike.throwUnionLengthOverflowException(hsSize.size)
+ }
+ hsSize.add(array.getLong(i))
+ i += 1
+ }
+ })
+ // store elements into array
+ val resultArray = new Array[Long](hsSize.size)
+ val hs = new OpenHashSet[Long]
+ var pos = 0
+ Seq(array1, array2).foreach(array => {
+ var i = 0
+ while (i < array.numElements()) {
+ val elem = array.getLong(i)
+ if (!hs.contains(elem)) {
+ resultArray(pos) = elem
+ hs.add(elem)
+ pos += 1
+ }
+ i += 1
+ }
+ })
+ if (ArraySetLike.useGenericArrayData(LongType.defaultSize,
resultArray.length)) {
+ new GenericArrayData(resultArray)
+ } else {
+ UnsafeArrayData.fromPrimitiveArray(resultArray)
+ }
+ case _ =>
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ val hs = new OpenHashSet[Any]
+ Seq(array1, array2).foreach(array => {
+ var i = 0
+ while (i < array.numElements()) {
+ val elem = array.get(i, elementType)
+ if (!hs.contains(elem)) {
+ if (arrayBuffer.size >
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
+
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
+ }
+ arrayBuffer += elem
+ hs.add(elem)
+ }
+ i += 1
+ }
+ })
+ new GenericArrayData(arrayBuffer)
+ }
+ } else {
+ ArraySetLike.evalUnionContainsNull(array1, array2, elementType,
+ if (elementTypeSupportEquals) null else ordering)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val i = ctx.freshName("i")
+ val pos = ctx.freshName("pos")
+ val value = ctx.freshName("value")
+ val size = ctx.freshName("size")
+ val genericArrayData = classOf[GenericArrayData].getName
+ val (postFix, classTag, getter, setter, javaTypeName, arrayBuilder) =
if (!cn) {
--- End diff --
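Don't we need to switch the codegen based on `elementTypeSupportEquals`, the same way `nullSafeEval` passes `null` or `ordering` to `evalUnionContainsNull` depending on it?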
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]