GitHub user kiszk commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21061#discussion_r182689118
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -505,3 +506,150 @@ case class ArrayMax(child: Expression) extends 
UnaryExpression with ImplicitCast
     
       override def prettyName: String = "array_max"
     }
    +
    +abstract class ArraySetUtils extends BinaryExpression with 
ExpectsInputTypes {
    +  val kindUnion = 1
    +  def typeId: Int
    +
    +  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, 
ArrayType)
    +
    +  override def checkInputDataTypes(): TypeCheckResult = {
    +    val r = super.checkInputDataTypes()
    +    if ((r == TypeCheckResult.TypeCheckSuccess) &&
    +      (left.dataType.asInstanceOf[ArrayType].elementType !=
    +        right.dataType.asInstanceOf[ArrayType].elementType)) {
    +      TypeCheckResult.TypeCheckFailure("Element type in both arrays must 
be the same")
    +    } else {
    +      r
    +    }
    +  }
    +
    +  override def dataType: DataType = left.dataType
    +
    +  private def elementType = dataType.asInstanceOf[ArrayType].elementType
    +  private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull
    +  private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull
    +
    +  override def nullSafeEval(input1: Any, input2: Any): Any = {
    +    val ary1 = input1.asInstanceOf[ArrayData]
    +    val ary2 = input2.asInstanceOf[ArrayData]
    +
    +    if (!cn1 && !cn2) {
    +      elementType match {
    +        case IntegerType =>
    +          // avoid boxing of primitive int array elements
    +          val hs = new OpenHashSet[Int]
    +          var i = 0
    +          while (i < ary1.numElements()) {
    +            hs.add(ary1.getInt(i))
    +            i += 1
    +          }
    +          i = 0
    +          while (i < ary2.numElements()) {
    +            hs.add(ary2.getInt(i))
    +            i += 1
    +          }
    +          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
    +        case LongType =>
    +          // avoid boxing of primitive long array elements
    +          val hs = new OpenHashSet[Long]
    +          var i = 0
    +          while (i < ary1.numElements()) {
    +            hs.add(ary1.getLong(i))
    +            i += 1
    +          }
    +          i = 0
    +          while (i < ary2.numElements()) {
    +            hs.add(ary2.getLong(i))
    +            i += 1
    +          }
    +          UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
    +        case _ =>
    +          val hs = new OpenHashSet[Any]
    +          var i = 0
    +          while (i < ary1.numElements()) {
    +            hs.add(ary1.get(i, elementType))
    +            i += 1
    +          }
    +          i = 0
    +          while (i < ary2.numElements()) {
    +            hs.add(ary2.get(i, elementType))
    +            i += 1
    +          }
    +          new GenericArrayData(hs.iterator.toArray)
    +      }
    +    } else {
    +      ArraySetUtils.arrayUnion(ary1, ary2, elementType)
    +    }
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val hs = ctx.freshName("hs")
    +    val i = ctx.freshName("i")
    +    val ArraySetUtils = 
"org.apache.spark.sql.catalyst.expressions.ArraySetUtils"
    +    val genericArrayData = classOf[GenericArrayData].getName
    +    val unsafeArrayData = classOf[UnsafeArrayData].getName
    +    val openHashSet = classOf[OpenHashSet[_]].getName
    +    val et = s"org.apache.spark.sql.types.DataTypes.$elementType"
    +    val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 && 
!cn2) {
    +      val ptName = CodeGenerator.primitiveTypeName(elementType)
    +      elementType match {
    +        case ByteType | ShortType | IntegerType =>
    +          (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", 
s"get$ptName($i)",
    +            s"$unsafeArrayData.fromPrimitiveArray", 
CodeGenerator.javaType(elementType))
    +        case LongType =>
    +          (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()", 
s"get$ptName($i)",
    +            s"$unsafeArrayData.fromPrimitiveArray", "long")
    +        case _ =>
    +          ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i, 
$et)",
    +            s"new $genericArrayData", "Object")
    +      }
    +    } else {
    +      ("", "", "", "", "")
    +    }
    +
    +    nullSafeCodeGen(ctx, ev, (ary1, ary2) => {
    +      if (classTag != "") {
    +        s"""
    +           |$openHashSet $hs = new $openHashSet$postFix($classTag);
    +           |for (int $i = 0; $i < $ary1.numElements(); $i++) {
    +           |  $hs.add$postFix($ary1.$getter);
    +           |}
    +           |for (int $i = 0; $i < $ary2.numElements(); $i++) {
    +           |  $hs.add$postFix($ary2.$getter);
    +           |}
    +           |${ev.value} = $arrayBuilder(($castType[]) 
$hs.iterator().toArray($classTag));
    --- End diff --
    
    Ah, great catch. I confirmed that `OpenHashSet$mcI$sp` has no 
specialized `iterator()`, so the primitive elements get boxed. I will rewrite this.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to