GitHub user kiszk commented on a diff in the pull request:
https://github.com/apache/spark/pull/21102#discussion_r205959224
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -3968,3 +3964,234 @@ object ArrayUnion {
new GenericArrayData(arrayBuffer)
}
}
+
+/**
+ * Returns an array of the elements in the intersection of x and y, without duplicates
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array1, array2) - Returns an array of the elements in the
intersection of array1 and
+ array2, without duplicates.
+ """,
+ examples = """
+ Examples:Fun
+ > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+ array(1, 3)
+ """,
+ since = "2.4.0")
+case class ArrayIntersect(left: Expression, right: Expression) extends
ArraySetLike {
+ override def dataType: DataType = ArrayType(elementType,
+ left.dataType.asInstanceOf[ArrayType].containsNull &&
+ right.dataType.asInstanceOf[ArrayType].containsNull)
+
+ @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData =
{
+ if (elementTypeSupportEquals) {
+ (array1, array2) =>
+ val hs = new OpenHashSet[Any]
+ val hsResult = new OpenHashSet[Any]
+ var foundNullElement = false
+ var i = 0
+ while (i < array2.numElements()) {
+ if (array2.isNullAt(i)) {
+ foundNullElement = true
+ } else {
+ val elem = array2.get(i, elementType)
+ hs.add(elem)
+ }
+ i += 1
+ }
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ i = 0
+ while (i < array1.numElements()) {
+ if (array1.isNullAt(i)) {
+ if (foundNullElement) {
+ arrayBuffer += null
+ foundNullElement = false
+ }
+ } else {
+ val elem = array1.get(i, elementType)
+ if (hs.contains(elem) && !hsResult.contains(elem)) {
+ arrayBuffer += elem
+ hsResult.add(elem)
+ }
+ }
+ i += 1
+ }
+ new GenericArrayData(arrayBuffer)
+ } else {
+ (array1, array2) =>
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ var alreadySeenNull = false
+ var i = 0
+ while (i < array1.numElements()) {
+ var found = false
+ val elem1 = array1.get(i, elementType)
+ if (array1.isNullAt(i)) {
+ if (!alreadySeenNull) {
+ var j = 0
+ while (!found && j < array2.numElements()) {
+ found = array2.isNullAt(j)
+ j += 1
+ }
+ // array2 is scanned only once for null element
+ alreadySeenNull = true
+ }
+ } else {
+ var j = 0
+ while (!found && j < array2.numElements()) {
+ if (!array2.isNullAt(j)) {
+ val elem2 = array2.get(j, elementType)
+ if (ordering.equiv(elem1, elem2)) {
+ // check whether elem1 is already stored in arrayBuffer
+ var foundArrayBuffer = false
+ var k = 0
+ while (!foundArrayBuffer && k < arrayBuffer.size) {
+ val va = arrayBuffer(k)
+ foundArrayBuffer = (va != null) && ordering.equiv(va,
elem1)
+ k += 1
+ }
+ found = !foundArrayBuffer
+ }
+ }
+ j += 1
+ }
+ }
+ if (found) {
+ arrayBuffer += elem1
+ }
+ i += 1
+ }
+ new GenericArrayData(arrayBuffer)
+ }
+ }
+
+ override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val array1 = input1.asInstanceOf[ArrayData]
+ val array2 = input2.asInstanceOf[ArrayData]
+
+ evalIntersect(array1, array2)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val arrayData = classOf[ArrayData].getName
+ val i = ctx.freshName("i")
--- End diff --
It would be good to extract L4077 to L4124 into a shared method, since this
part could be reused by `union`, `except`, and `intersect`.
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]