Github user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21061#discussion_r182689118
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -505,3 +506,150 @@ case class ArrayMax(child: Expression) extends
UnaryExpression with ImplicitCast
override def prettyName: String = "array_max"
}
+
+abstract class ArraySetUtils extends BinaryExpression with
ExpectsInputTypes {
+ val kindUnion = 1
+ def typeId: Int
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType,
ArrayType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val r = super.checkInputDataTypes()
+ if ((r == TypeCheckResult.TypeCheckSuccess) &&
+ (left.dataType.asInstanceOf[ArrayType].elementType !=
+ right.dataType.asInstanceOf[ArrayType].elementType)) {
+ TypeCheckResult.TypeCheckFailure("Element type in both arrays must
be the same")
+ } else {
+ r
+ }
+ }
+
+ override def dataType: DataType = left.dataType
+
+ private def elementType = dataType.asInstanceOf[ArrayType].elementType
+ private def cn1 = left.dataType.asInstanceOf[ArrayType].containsNull
+ private def cn2 = right.dataType.asInstanceOf[ArrayType].containsNull
+
+ override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val ary1 = input1.asInstanceOf[ArrayData]
+ val ary2 = input2.asInstanceOf[ArrayData]
+
+ if (!cn1 && !cn2) {
+ elementType match {
+ case IntegerType =>
+ // avoid boxing of primitive int array elements
+ val hs = new OpenHashSet[Int]
+ var i = 0
+ while (i < ary1.numElements()) {
+ hs.add(ary1.getInt(i))
+ i += 1
+ }
+ i = 0
+ while (i < ary2.numElements()) {
+ hs.add(ary2.getInt(i))
+ i += 1
+ }
+ UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
+ case LongType =>
+ // avoid boxing of primitive long array elements
+ val hs = new OpenHashSet[Long]
+ var i = 0
+ while (i < ary1.numElements()) {
+ hs.add(ary1.getLong(i))
+ i += 1
+ }
+ i = 0
+ while (i < ary2.numElements()) {
+ hs.add(ary2.getLong(i))
+ i += 1
+ }
+ UnsafeArrayData.fromPrimitiveArray(hs.iterator.toArray)
+ case _ =>
+ val hs = new OpenHashSet[Any]
+ var i = 0
+ while (i < ary1.numElements()) {
+ hs.add(ary1.get(i, elementType))
+ i += 1
+ }
+ i = 0
+ while (i < ary2.numElements()) {
+ hs.add(ary2.get(i, elementType))
+ i += 1
+ }
+ new GenericArrayData(hs.iterator.toArray)
+ }
+ } else {
+ ArraySetUtils.arrayUnion(ary1, ary2, elementType)
+ }
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val hs = ctx.freshName("hs")
+ val i = ctx.freshName("i")
+ val ArraySetUtils =
"org.apache.spark.sql.catalyst.expressions.ArraySetUtils"
+ val genericArrayData = classOf[GenericArrayData].getName
+ val unsafeArrayData = classOf[UnsafeArrayData].getName
+ val openHashSet = classOf[OpenHashSet[_]].getName
+ val et = s"org.apache.spark.sql.types.DataTypes.$elementType"
+ val (postFix, classTag, getter, arrayBuilder, castType) = if (!cn1 &&
!cn2) {
+ val ptName = CodeGenerator.primitiveTypeName(elementType)
+ elementType match {
+ case ByteType | ShortType | IntegerType =>
+ (s"$$mcI$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()",
s"get$ptName($i)",
+ s"$unsafeArrayData.fromPrimitiveArray",
CodeGenerator.javaType(elementType))
+ case LongType =>
+ (s"$$mcJ$$sp", s"scala.reflect.ClassTag$$.MODULE$$.$ptName()",
s"get$ptName($i)",
+ s"$unsafeArrayData.fromPrimitiveArray", "long")
+ case _ =>
+ ("", s"scala.reflect.ClassTag$$.MODULE$$.Object()", s"get($i,
$et)",
+ s"new $genericArrayData", "Object")
+ }
+ } else {
+ ("", "", "", "", "")
+ }
+
+ nullSafeCodeGen(ctx, ev, (ary1, ary2) => {
+ if (classTag != "") {
+ s"""
+ |$openHashSet $hs = new $openHashSet$postFix($classTag);
+ |for (int $i = 0; $i < $ary1.numElements(); $i++) {
+ | $hs.add$postFix($ary1.$getter);
+ |}
+ |for (int $i = 0; $i < $ary2.numElements(); $i++) {
+ | $hs.add$postFix($ary2.$getter);
+ |}
+ |${ev.value} = $arrayBuilder(($castType[])
$hs.iterator().toArray($classTag));
--- End diff --
Ah, great catch. I confirmed there is no specialized `iterator()` in
`OpenHashSet$mcI$sp`. I will rewrite this.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]