Github user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21061#discussion_r192340073
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -1882,3 +1882,311 @@ case class ArrayRepeat(left: Expression, right: 
Expression)
       }
     
     }
    +
    +object ArraySetLike {
    +  val kindUnion = 1
    +
    +  private val MAX_ARRAY_LENGTH: Int = 
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
    +
    +  def toArrayDataInt(hs: OpenHashSet[Int]): ArrayData = {
    +    val array = new Array[Int](hs.size)
    +    var pos = hs.nextPos(0)
    +    var i = 0
    +    while (pos != OpenHashSet.INVALID_POS) {
    +      array(i) = hs.getValue(pos)
    +      pos = hs.nextPos(pos + 1)
    +      i += 1
    +    }
    +
    +    val numBytes = 4L * array.length
    +    val unsafeArraySizeInBytes = 
UnsafeArrayData.calculateHeaderPortionInBytes(array.length) +
    +      
org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
    +    // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max 
elements * 8 bytes can be used
    +    if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) {
    +      UnsafeArrayData.fromPrimitiveArray(array)
    +    } else {
    +      new GenericArrayData(array)
    +    }
    +  }
    +
    +  def toArrayDataLong(hs: OpenHashSet[Long]): ArrayData = {
    +    val array = new Array[Long](hs.size)
    +    var pos = hs.nextPos(0)
    +    var i = 0
    +    while (pos != OpenHashSet.INVALID_POS) {
    +      array(i) = hs.getValue(pos)
    +      pos = hs.nextPos(pos + 1)
    +      i += 1
    +    }
    +
    +    val numBytes = 8L * array.length
    +    val unsafeArraySizeInBytes = 
UnsafeArrayData.calculateHeaderPortionInBytes(array.length) +
    +      
org.apache.spark.unsafe.array.ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
    +    // Since UnsafeArrayData.fromPrimitiveArray() uses long[], max 
elements * 8 bytes can be used
    +    if (unsafeArraySizeInBytes <= Integer.MAX_VALUE * 8) {
    +      UnsafeArrayData.fromPrimitiveArray(array)
    +    } else {
    +      new GenericArrayData(array)
    +    }
    +  }
    +
    +  def arrayUnion(
    +      array1: ArrayData,
    +      array2: ArrayData,
    +      et: DataType,
    +      ordering: Ordering[Any]): ArrayData = {
    +    if (ordering == null) {
    +      new 
GenericArrayData(array1.toObjectArray(et).union(array2.toObjectArray(et))
    +        .distinct.asInstanceOf[Array[Any]])
    +    } else {
    +      val length = math.min(array1.numElements().toLong + 
array2.numElements().toLong,
    +        ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH)
    +      val array = new Array[Any](length.toInt)
    +      var pos = 0
    +      var hasNull = false
    +      Seq(array1, array2).foreach(_.foreach(et, (_, v) => {
    +        var found = false
    +        if (v == null) {
    +          if (hasNull) {
    +            found = true
    +          } else {
    +            hasNull = true
    +          }
    +        } else {
    +          var j = 0
    +          while (!found && j < pos) {
    +            val va = array(j)
    +            if (va != null && ordering.equiv(va, v)) {
    +              found = true
    +            }
    +            j = j + 1
    +          }
    +        }
    +        if (!found) {
    +          if (pos > MAX_ARRAY_LENGTH) {
    +            throw new RuntimeException(s"Unsuccessful try to union arrays 
with $pos" +
    +              s" elements due to exceeding the array size limit 
$MAX_ARRAY_LENGTH.")
    +          }
    +          array(pos) = v
    +          pos = pos + 1
    +        }
    +      }))
    +      new GenericArrayData(array.slice(0, pos))
    +    }
    +  }
    +}
    +
    +abstract class ArraySetLike extends BinaryArrayExpressionWithImplicitCast {
    +  def typeId: Int
    +
    +  override def dataType: DataType = left.dataType
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val typeCheckResult = super.checkInputDataTypes()
    +    if (typeCheckResult.isSuccess) {
    +      
TypeUtils.checkForOrderingExpr(dataType.asInstanceOf[ArrayType].elementType,
    +        s"function $prettyName")
    +    } else {
    +      typeCheckResult
    +    }
    +  }
    +
    +  private def cn = left.dataType.asInstanceOf[ArrayType].containsNull ||
    +    right.dataType.asInstanceOf[ArrayType].containsNull
    +
    +  @transient private lazy val ordering: Ordering[Any] =
    +    TypeUtils.getInterpretedOrdering(elementType)
    +
    +  @transient private lazy val elementTypeSupportEquals = elementType match 
{
    +    case BinaryType => false
    +    case _: AtomicType => true
    +    case _ => false
    +  }
    +
    +  def intEval(ary: ArrayData, hs2: OpenHashSet[Int]): OpenHashSet[Int]
    +  def longEval(ary: ArrayData, hs2: OpenHashSet[Long]): OpenHashSet[Long]
    +  def genericEval(ary: ArrayData, hs2: OpenHashSet[Any], et: DataType): 
OpenHashSet[Any]
    +  def codeGen(ctx: CodegenContext, hs2: String, hs: String, len: String, 
getter: String, i: String,
    +    postFix: String, newOpenHashSet: String): String
    +
    +  override def nullSafeEval(input1: Any, input2: Any): Any = {
    +    val ary1 = input1.asInstanceOf[ArrayData]
    +    val ary2 = input2.asInstanceOf[ArrayData]
    +
    +    if (!cn) {
    +      elementType match {
    +        case IntegerType =>
    +          // avoid boxing of primitive int array elements
    +          val hs2 = new OpenHashSet[Int]
    +          var i = 0
    +          while (i < ary2.numElements()) {
    +            hs2.add(ary2.getInt(i))
    +            i += 1
    +          }
    +          ArraySetLike.toArrayDataInt(intEval(ary1, hs2))
    +        case LongType =>
    +          // avoid boxing of primitive long array elements
    +          val hs2 = new OpenHashSet[Long]
    +          var i = 0
    +          while (i < ary2.numElements()) {
    +            hs2.add(ary2.getLong(i))
    +            i += 1
    +          }
    +          ArraySetLike.toArrayDataLong(longEval(ary1, hs2))
    +        case _ =>
    +          val hs2 = new OpenHashSet[Any]
    +          var i = 0
    +          while (i < ary2.numElements()) {
    +            hs2.add(ary2.get(i, elementType))
    +            i += 1
    +          }
    +          new GenericArrayData(genericEval(ary1, hs2, 
elementType).iterator.toArray)
    +      }
    +    } else {
    +      if (typeId == ArraySetLike.kindUnion) {
    --- End diff --
    
    I eliminated the `typeId` field by defining a method in the concrete class instead.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to