Github user kiszk commented on a diff in the pull request: https://github.com/apache/spark/pull/21103#discussion_r206008046 --- Diff: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala --- @@ -3968,3 +3964,267 @@ object ArrayUnion { new GenericArrayData(arrayBuffer) } } + +/** + * Returns an array of the elements in the intersect of x and y, without duplicates + */ +@ExpressionDescription( + usage = """ + _FUNC_(array1, array2) - Returns an array of the elements in array1 but not in array2, + without duplicates. + """, + examples = """ + Examples: + > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5)); + array(2) + """, + since = "2.4.0") +case class ArrayExcept(left: Expression, right: Expression) extends ArraySetLike + with ComplexTypeMergingExpression { + override def dataType: DataType = { + dataTypeCheck + left.dataType + } + + @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = { + if (elementTypeSupportEquals) { + (array1, array2) => + val hs = new OpenHashSet[Any] + var notFoundNullElement = true + var i = 0 + while (i < array2.numElements()) { + if (array2.isNullAt(i)) { + notFoundNullElement = false + } else { + val elem = array2.get(i, elementType) + hs.add(elem) + } + i += 1 + } + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + i = 0 + while (i < array1.numElements()) { + if (array1.isNullAt(i)) { + if (notFoundNullElement) { + arrayBuffer += null + notFoundNullElement = false + } + } else { + val elem = array1.get(i, elementType) + if (!hs.contains(elem)) { + arrayBuffer += elem + hs.add(elem) + } + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } else { + (array1, array2) => + val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any] + var scannedNullElements = false + var i = 0 + while (i < array1.numElements()) { + var found = false + val elem1 = array1.get(i, elementType) + if (elem1 == null) { + if (!scannedNullElements) { + var j = 0 + while (!found && j < 
array2.numElements()) { + found = array2.isNullAt(j) + j += 1 + } + // array2 is scanned only once for null element + scannedNullElements = true + } else { + found = true + } + } else { + var j = 0 + while (!found && j < array2.numElements()) { + val elem2 = array2.get(j, elementType) + if (elem2 != null) { + found = ordering.equiv(elem1, elem2) + } + j += 1 + } + if (!found) { + // check whether elem1 is already stored in arrayBuffer + var k = 0 + while (!found && k < arrayBuffer.size) { + val va = arrayBuffer(k) + found = (va != null) && ordering.equiv(va, elem1) + k += 1 + } + } + } + if (!found) { + arrayBuffer += elem1 + } + i += 1 + } + new GenericArrayData(arrayBuffer) + } + } + + override def nullSafeEval(input1: Any, input2: Any): Any = { + val array1 = input1.asInstanceOf[ArrayData] + val array2 = input2.asInstanceOf[ArrayData] + + evalExcept(array1, array2) + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val arrayData = classOf[ArrayData].getName + val i = ctx.freshName("i") + val pos = ctx.freshName("pos") + val value = ctx.freshName("value") + val hsValue = ctx.freshName("hsValue") + val size = ctx.freshName("size") + val (postFix, openHashElementType, hsJavaTypeName, genHsValue, + getter, setter, javaTypeName, arrayBuilder) = + if (elementTypeSupportEquals) { + elementType match { + case BooleanType | ByteType | ShortType | IntegerType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + ("$mcI$sp", "Int", "int", + if (elementType != BooleanType) { + s"(int) $value" + } else { + s"$value ? 
1 : 0;" + }, + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case LongType | FloatType | DoubleType => + val ptName = CodeGenerator.primitiveTypeName(elementType) + val unsafeArray = ctx.freshName("unsafeArray") + val signature = elementType match { + case LongType => "$mcJ$sp" + case FloatType => "$mcF$sp" + case DoubleType => "$mcD$sp" + } + (signature, CodeGenerator.boxedType(elementType), + CodeGenerator.javaType(elementType), value, + s"get$ptName($i)", s"set$ptName($pos, $value)", CodeGenerator.javaType(elementType), + s""" + |${ctx.createUnsafeArray(unsafeArray, size, elementType, s" $prettyName failed.")} + |${ev.value} = $unsafeArray; + """.stripMargin) + case _ => + val genericArrayData = classOf[GenericArrayData].getName + val et = ctx.addReferenceObj("elementType", elementType) + ("", "Object", "Object", value, + s"get($i, $et)", s"update($pos, $value)", "Object", + s"${ev.value} = new $genericArrayData(new Object[$size]);") + } + } else { + ("", "", "", "", "", "", "", "") + } + + nullSafeCodeGen(ctx, ev, (array1, array2) => { + if (openHashElementType != "") { + // Here, we ensure elementTypeSupportEquals is true + val notFoundNullElement = ctx.freshName("notFoundNullElement") + val openHashSet = classOf[OpenHashSet[_]].getName + val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()" + val hs = ctx.freshName("hs") + val arrayData = classOf[ArrayData].getName + val arrays = ctx.freshName("arrays") + val array = ctx.freshName("array") + val arrayDataIdx = ctx.freshName("arrayDataIdx") + + val array2NullCheck = if (right.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array2.isNullAt($i)) { + | $notFoundNullElement = false; + |} else + """.stripMargin + } else { + "" + } + val array1NullCheck = if 
(left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | $size++; + | $notFoundNullElement = false; + | } + |} else + """.stripMargin + } else { + "" + } + val array1NullAssignment = if (left.dataType.asInstanceOf[ArrayType].containsNull) { + s""" + |if ($array1.isNullAt($i)) { + | if ($notFoundNullElement) { + | ${ev.value}.setNullAt($pos++); + | $notFoundNullElement = false; + | } + |} else + """.stripMargin + } else { + "" + } + + s""" + |$openHashSet $hs = new $openHashSet$postFix($classTag); + |boolean $notFoundNullElement = true; + |int $size = 0; + |for (int $i = 0; $i < $array2.numElements(); $i++) { + | $array2NullCheck + | { + | $javaTypeName $value = $array2.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | $hs.add$postFix($hsValue); + | } + |} + |for (int $i = 0; $i < $array1.numElements(); $i++) { + | $array1NullCheck + | { + | $javaTypeName $value = $array1.$getter; + | $hsJavaTypeName $hsValue = $genHsValue; + | if (!$hs.contains($hsValue)) { + | $hs.add$postFix($hsValue); + | $size++; + | } + | } + |} + |$arrayBuilder + |$hs = new $openHashSet$postFix($classTag); + |$notFoundNullElement = true; + |int $pos = 0; + |for (int $i = 0; $i < $array2.numElements(); $i++) { --- End diff -- Using 'ArrayBuffer' involves boxing. Let me consider using 'ArrayBuilder.ofInt' or one of the other primitive-specialized builders, together with special handling for null elements.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org