Github user ueshin commented on a diff in the pull request: https://github.com/apache/spark/pull/21102#discussion_r207766511 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -3965,6 +4034,242 @@ object ArrayUnion { } } +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in the intersection of array1 and + array2, without duplicates. + """, + examples = """ + Examples:Fun + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(1, 3) + """, + since = "2.4.0") +case class ArrayIntersect(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + ArrayType(elementType, + left.dataType.asInstanceOf[ArrayType].containsNull && + right.dataType.asInstanceOf[ArrayType].containsNull) + } + + @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + val hs = new OpenHashSet[Any] + val hsResult = new OpenHashSet[Any] + var foundNullElement = false + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + foundNullElement = true + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (foundNullElement) { + arrayBuffer += null + foundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (hs.contains(elem) && !hsResult.contains(elem)) { + arrayBuffer += elem + hsResult.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var alreadySeenNull = false + var i = 0 + while (i < 
array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (array1.isNullAt(i)) { + if (!alreadySeenNull) { + var j = 0 + while (!found && j < array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + alreadySeenNull = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + if (!array2.isNullAt(j)) { + val elem2 = array2.get(j, elementType) + if (ordering.equiv(elem1, elem2)) { + // check whether elem1 is already stored in arrayBuffer + var foundArrayBuffer = false + var k = 0 + while (!foundArrayBuffer && k < arrayBuffer.size) { + val va = arrayBuffer(k) + foundArrayBuffer = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + found = !foundArrayBuffer + } + } + j += 1 + } + } + if (found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalIntersect(array1, array2) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName + val i = ctx.freshName("i") + val value = ctx.freshName("value") + val size = ctx.freshName("size") + if (canUseSpecializedHashSet) { + val jt = CodeGenerator.javaType(elementType) + val ptName = CodeGenerator.primitiveTypeName(jt) + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + val foundNullElement = ctx.freshName("foundNullElement") + val nullElementIndex = ctx.freshName("nullElementIndex") + val builder = ctx.freshName("builder") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()" + val hashSet = ctx.freshName("hashSet") + val hashSetResult = ctx.freshName("hashSetResult") + val arrayBuilder = "scala.collection.mutable.ArrayBuilder" + val arrayBuilderClass = s"$arrayBuilder$$of$ptName" + val 
arrayBuilderClassTag = s"scala.reflect.ClassTag$$.MODULE$$.$ptName()" + + def withArray2NullCheck(body: String): String = + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $foundNullElement = true; + |} else { + | $body + |} + """.stripMargin + } else { + // if array1's element is not nullable, we don't need to track the null element index. + s""" + |if (!$array2.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val writeArray2ToHashSet = withArray2NullCheck( + s""" + |$jt $value = ${genGetValue(array2, i)}; + |$hashSet.add$hsPostFix($hsValueCast$value); + """.stripMargin) + + def withArray1NullAssignment(body: String) = + if (left.dataType.asInstanceOf[ArrayType].containsNull) { + if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($foundNullElement) { + | $nullElementIndex = $size; + | $foundNullElement = false; + | $size++; + | $builder.$$plus$$eq($nullValueHolder); + | } + |} else { + | $body + |} + """.stripMargin + } else { + s""" + |if (!$array1.isNullAt($i)) { + | $body + |} + """.stripMargin + } + } else { + body + } + + val processArray1 = withArray1NullAssignment( + s""" + |$jt $value = ${genGetValue(array1, i)}; + |if ($hashSet.contains($hsValueCast$value) && + | !$hashSetResult.contains($hsValueCast$value)) { + | if (++$size > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | break; + | } + | $hashSetResult.add$hsPostFix($hsValueCast$value); + | $builder.$$plus$$eq($value); + |} + """.stripMargin) + + // Only need to track null element index when result array's element is nullable. 
+ val declareNullTrackVariables = if (dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |boolean $foundNullElement = false; + |int $nullElementIndex = -1; + """.stripMargin + } else { + "" + } + + s""" + |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag); + |$openHashSet $hashSetResult = new $openHashSet$hsPostFix($classTag); + |$declareNullTrackVariables + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $writeArray2ToHashSet + |} + |$arrayBuilderClass $builder = + | ($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag); --- End diff -- nit: wouldn't a plain `new $arrayBuilderClass()` work here, instead of casting the result of `$arrayBuilder.make($arrayBuilderClassTag)`?
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org