Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21061#discussion_r201620967
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -3261,3 +3261,322 @@ case class ArrayDistinct(child: Expression)
     
       override def prettyName: String = "array_distinct"
     }
    +
    +/**
    + * Will become common base class for [[ArrayUnion]], ArrayIntersect, and 
ArrayExcept.
    + */
    +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
    +  override def dataType: DataType = {
    +    val dataTypes = children.map(_.dataType.asInstanceOf[ArrayType])
    +    ArrayType(elementType, dataTypes.exists(_.containsNull))
    +  }
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val typeCheckResult = super.checkInputDataTypes()
    +    if (typeCheckResult.isSuccess) {
    +      
TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
    +        s"function $prettyName")
    +    } else {
    +      typeCheckResult
    +    }
    +  }
    +
    +  @transient protected lazy val ordering: Ordering[Any] =
    +    TypeUtils.getInterpretedOrdering(elementType)
    +
    +  @transient protected lazy val elementTypeSupportEquals = elementType 
match {
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +}
    +
    +object ArraySetLike {
    +  def throwUnionLengthOverflowException(length: Int): Unit = {
    +    throw new RuntimeException(s"Unsuccessful try to union arrays with 
$length " +
    +      s"elements due to exceeding the array size limit " +
    +      s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
    +  }
    +}
    +
    +
    +/**
    + * Returns an array of the elements in the union of x and y, without 
duplicates
    + */
    +@ExpressionDescription(
    +  usage = """
    +    _FUNC_(array1, array2) - Returns an array of the elements in the union 
of array1 and array2,
    +      without duplicates.
    +  """,
    +  examples = """
    +    Examples:
    +      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
    +       array(1, 2, 3, 5)
    +  """,
    +  since = "2.4.0")
    +case class ArrayUnion(left: Expression, right: Expression) extends 
ArraySetLike {
    +  var hsInt: OpenHashSet[Int] = _
    +  var hsLong: OpenHashSet[Long] = _
    +
    +  def assignInt(array: ArrayData, idx: Int, resultArray: ArrayData, pos: 
Int): Boolean = {
    +    val elem = array.getInt(idx)
    +    if (!hsInt.contains(elem)) {
    +      if (resultArray != null) {
    +        resultArray.setInt(pos, elem)
    +      }
    +      hsInt.add(elem)
    +      true
    +    } else {
    +      false
    +    }
    +  }
    +
    +  def assignLong(array: ArrayData, idx: Int, resultArray: ArrayData, pos: 
Int): Boolean = {
    +    val elem = array.getLong(idx)
    +    if (!hsLong.contains(elem)) {
    +      if (resultArray != null) {
    +        resultArray.setLong(pos, elem)
    +      }
    +      hsLong.add(elem)
    +      true
    +    } else {
    +      false
    +    }
    +  }
    +
    +  def evalIntLongPrimitiveType(
    +      array1: ArrayData,
    +      array2: ArrayData,
    +      resultArray: ArrayData,
    +      isLongType: Boolean): Int = {
    +    // store elements into resultArray
    +    var nullElementSize = 0
    +    var pos = 0
    +    Seq(array1, array2).foreach { array =>
    +      var i = 0
    +      while (i < array.numElements()) {
    +        val size = if (!isLongType) hsInt.size else hsLong.size
    +        if (size + nullElementSize > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +          ArraySetLike.throwUnionLengthOverflowException(size)
    +        }
    +        if (array.isNullAt(i)) {
    +          if (nullElementSize == 0) {
    +            if (resultArray != null) {
    +              resultArray.setNullAt(pos)
    +            }
    +            pos += 1
    +            nullElementSize = 1
    +          }
    +        } else {
    +          val assigned = if (!isLongType) {
    +            assignInt(array, i, resultArray, pos)
    +          } else {
    +            assignLong(array, i, resultArray, pos)
    +          }
    +          if (assigned) {
    +            pos += 1
    +          }
    +        }
    +        i += 1
    +      }
    +    }
    +    pos
    +  }
    +
    +  override def nullSafeEval(input1: Any, input2: Any): Any = {
    +    val array1 = input1.asInstanceOf[ArrayData]
    +    val array2 = input2.asInstanceOf[ArrayData]
    +
    +    if (elementTypeSupportEquals) {
    +      elementType match {
    +        case IntegerType =>
    +          // avoid boxing of primitive int array elements
    +          // calculate result array size
    +          hsInt = new OpenHashSet[Int]
    +          val elements = evalIntLongPrimitiveType(array1, array2, null, 
false)
    +          hsInt = new OpenHashSet[Int]
    +          val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
    +            IntegerType.defaultSize, elements)) {
    +            new GenericArrayData(new Array[Any](elements))
    +          } else {
    +            UnsafeArrayData.forPrimitiveArray(
    +              Platform.INT_ARRAY_OFFSET, elements, IntegerType.defaultSize)
    +          }
    +          evalIntLongPrimitiveType(array1, array2, resultArray, false)
    +          resultArray
    +        case LongType =>
    +          // avoid boxing of primitive long array elements
    +          // calculate result array size
    +          hsLong = new OpenHashSet[Long]
    +          val elements = evalIntLongPrimitiveType(array1, array2, null, 
true)
    +          hsLong = new OpenHashSet[Long]
    +          val resultArray = if (UnsafeArrayData.canUseGenericArrayData(
    +            LongType.defaultSize, elements)) {
    +            new GenericArrayData(new Array[Any](elements))
    +          } else {
    +            UnsafeArrayData.forPrimitiveArray(
    +              Platform.LONG_ARRAY_OFFSET, elements, LongType.defaultSize)
    +          }
    +          evalIntLongPrimitiveType(array1, array2, resultArray, true)
    +          resultArray
    +        case _ =>
    +          val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +          val hs = new OpenHashSet[Any]
    +          var foundNullElement = false
    +          Seq(array1, array2).foreach { array =>
    +            var i = 0
    +            while (i < array.numElements()) {
    +              if (array.isNullAt(i)) {
    +                if (!foundNullElement) {
    +                  arrayBuffer += null
    +                  foundNullElement = true
    +                }
    +              } else {
    +                val elem = array.get(i, elementType)
    +                if (!hs.contains(elem)) {
    +                  if (arrayBuffer.size > 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
    +                    
ArraySetLike.throwUnionLengthOverflowException(arrayBuffer.size)
    +                  }
    +                  arrayBuffer += elem
    +                  hs.add(elem)
    +                }
    +              }
    +              i += 1
    +            }
    +          }
    +          new GenericArrayData(arrayBuffer)
    +      }
    +    } else {
    +      ArrayUnion.unionOrdering(array1, array2, elementType, ordering)
    +    }
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val i = ctx.freshName("i")
    +    val pos = ctx.freshName("pos")
    +    val value = ctx.freshName("value")
    +    val size = ctx.freshName("size")
    +    val (postFix, openHashElementType, getter, setter, javaTypeName, 
castOp, arrayBuilder) =
    +      if (elementTypeSupportEquals) {
    +        elementType match {
    +          case ByteType | ShortType | IntegerType | LongType =>
    +            val ptName = CodeGenerator.primitiveTypeName(elementType)
    +            val unsafeArray = ctx.freshName("unsafeArray")
    +            (if (elementType == LongType) s"$$mcJ$$sp" else s"$$mcI$$sp",
    +              if (elementType == LongType) "Long" else "Int",
    +              s"get$ptName($i)", s"set$ptName($pos, $value)", 
CodeGenerator.javaType(elementType),
    +              if (elementType == LongType) "(long)" else "(int)",
    +              s"""
    +                 |${ctx.createUnsafeArray(unsafeArray, size, elementType, 
s" $prettyName failed.")}
    --- End diff --
    
    Looks like we don't automatically choose to use `GenericArrayData` here, 
the way the interpreted path does?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to