Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/21103#discussion_r205999925
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -3968,3 +3964,267 @@ object ArrayUnion {
new GenericArrayData(arrayBuffer)
}
}
+
+/**
+ * Returns an array of the elements in x but not in y, without
duplicates
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array1, array2) - Returns an array of the elements in array1 but
not in array2,
+ without duplicates.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+ array(2)
+ """,
+ since = "2.4.0")
+case class ArrayExcept(left: Expression, right: Expression) extends
ArraySetLike
+ with ComplexTypeMergingExpression {
+ override def dataType: DataType = {
+ dataTypeCheck
+ left.dataType
+ }
+
+ @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
+ if (elementTypeSupportEquals) {
+ (array1, array2) =>
+ val hs = new OpenHashSet[Any]
+ var notFoundNullElement = true
+ var i = 0
+ while (i < array2.numElements()) {
+ if (array2.isNullAt(i)) {
+ notFoundNullElement = false
+ } else {
+ val elem = array2.get(i, elementType)
+ hs.add(elem)
+ }
+ i += 1
+ }
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ i = 0
+ while (i < array1.numElements()) {
+ if (array1.isNullAt(i)) {
+ if (notFoundNullElement) {
+ arrayBuffer += null
+ notFoundNullElement = false
+ }
+ } else {
+ val elem = array1.get(i, elementType)
+ if (!hs.contains(elem)) {
+ arrayBuffer += elem
+ hs.add(elem)
+ }
+ }
+ i += 1
+ }
+ new GenericArrayData(arrayBuffer)
+ } else {
+ (array1, array2) =>
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ var scannedNullElements = false
+ var i = 0
+ while (i < array1.numElements()) {
+ var found = false
+ val elem1 = array1.get(i, elementType)
+ if (elem1 == null) {
+ if (!scannedNullElements) {
+ var j = 0
+ while (!found && j < array2.numElements()) {
+ found = array2.isNullAt(j)
+ j += 1
+ }
+ // array2 is scanned only once for null element
+ scannedNullElements = true
+ } else {
+ found = true
+ }
+ } else {
+ var j = 0
+ while (!found && j < array2.numElements()) {
+ val elem2 = array2.get(j, elementType)
+ if (elem2 != null) {
+ found = ordering.equiv(elem1, elem2)
+ }
+ j += 1
+ }
+ if (!found) {
+ // check whether elem1 is already stored in arrayBuffer
+ var k = 0
+ while (!found && k < arrayBuffer.size) {
+ val va = arrayBuffer(k)
+ found = (va != null) && ordering.equiv(va, elem1)
+ k += 1
+ }
+ }
+ }
+ if (!found) {
+ arrayBuffer += elem1
+ }
+ i += 1
+ }
+ new GenericArrayData(arrayBuffer)
+ }
+ }
+
+ override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val array1 = input1.asInstanceOf[ArrayData]
+ val array2 = input2.asInstanceOf[ArrayData]
+
+ evalExcept(array1, array2)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val arrayData = classOf[ArrayData].getName
+ val i = ctx.freshName("i")
+ val pos = ctx.freshName("pos")
+ val value = ctx.freshName("value")
+ val hsValue = ctx.freshName("hsValue")
+ val size = ctx.freshName("size")
+ val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
+ getter, setter, javaTypeName, arrayBuilder) =
+ if (elementTypeSupportEquals) {
+ elementType match {
+ case BooleanType | ByteType | ShortType | IntegerType =>
+ val ptName = CodeGenerator.primitiveTypeName(elementType)
+ val unsafeArray = ctx.freshName("unsafeArray")
+ ("$mcI$sp", "Int", "int",
+ if (elementType != BooleanType) {
+ s"(int) $value"
+ } else {
+ s"$value ? 1 : 0;"
+ },
+ s"get$ptName($i)", s"set$ptName($pos, $value)",
CodeGenerator.javaType(elementType),
+ s"""
+ |${ctx.createUnsafeArray(unsafeArray, size, elementType,
s" $prettyName failed.")}
+ |${ev.value} = $unsafeArray;
+ """.stripMargin)
+ case LongType | FloatType | DoubleType =>
+ val ptName = CodeGenerator.primitiveTypeName(elementType)
+ val unsafeArray = ctx.freshName("unsafeArray")
+ val signature = elementType match {
+ case LongType => "$mcJ$sp"
+ case FloatType => "$mcF$sp"
+ case DoubleType => "$mcD$sp"
+ }
+ (signature, CodeGenerator.boxedType(elementType),
+ CodeGenerator.javaType(elementType), value,
+ s"get$ptName($i)", s"set$ptName($pos, $value)",
CodeGenerator.javaType(elementType),
+ s"""
+ |${ctx.createUnsafeArray(unsafeArray, size, elementType,
s" $prettyName failed.")}
+ |${ev.value} = $unsafeArray;
+ """.stripMargin)
+ case _ =>
+ val genericArrayData = classOf[GenericArrayData].getName
+ val et = ctx.addReferenceObj("elementType", elementType)
+ ("", "Object", "Object", value,
+ s"get($i, $et)", s"update($pos, $value)", "Object",
+ s"${ev.value} = new $genericArrayData(new Object[$size]);")
+ }
+ } else {
+ ("", "", "", "", "", "", "", "")
+ }
+
+ nullSafeCodeGen(ctx, ev, (array1, array2) => {
+ if (openHashElementType != "") {
+ // Here, elementTypeSupportEquals is guaranteed to be true
+ val notFoundNullElement = ctx.freshName("notFoundNullElement")
+ val openHashSet = classOf[OpenHashSet[_]].getName
+ val classTag =
s"scala.reflect.ClassTag$$.MODULE$$.$openHashElementType()"
+ val hs = ctx.freshName("hs")
+ val arrayData = classOf[ArrayData].getName
+ val arrays = ctx.freshName("arrays")
+ val array = ctx.freshName("array")
+ val arrayDataIdx = ctx.freshName("arrayDataIdx")
+
+ val array2NullCheck = if
(right.dataType.asInstanceOf[ArrayType].containsNull) {
+ s"""
+ |if ($array2.isNullAt($i)) {
+ | $notFoundNullElement = false;
+ |} else
+ """.stripMargin
+ } else {
+ ""
+ }
+ val array1NullCheck = if
(left.dataType.asInstanceOf[ArrayType].containsNull) {
+ s"""
+ |if ($array1.isNullAt($i)) {
+ | if ($notFoundNullElement) {
+ | $size++;
+ | $notFoundNullElement = false;
+ | }
+ |} else
+ """.stripMargin
+ } else {
+ ""
+ }
+ val array1NullAssignment = if
(left.dataType.asInstanceOf[ArrayType].containsNull) {
+ s"""
+ |if ($array1.isNullAt($i)) {
+ | if ($notFoundNullElement) {
+ | ${ev.value}.setNullAt($pos++);
+ | $notFoundNullElement = false;
+ | }
+ |} else
+ """.stripMargin
+ } else {
+ ""
+ }
+
+ s"""
+ |$openHashSet $hs = new $openHashSet$postFix($classTag);
+ |boolean $notFoundNullElement = true;
+ |int $size = 0;
+ |for (int $i = 0; $i < $array2.numElements(); $i++) {
+ | $array2NullCheck
+ | {
+ | $javaTypeName $value = $array2.$getter;
+ | $hsJavaTypeName $hsValue = $genHsValue;
+ | $hs.add$postFix($hsValue);
+ | }
+ |}
+ |for (int $i = 0; $i < $array1.numElements(); $i++) {
+ | $array1NullCheck
+ | {
+ | $javaTypeName $value = $array1.$getter;
+ | $hsJavaTypeName $hsValue = $genHsValue;
+ | if (!$hs.contains($hsValue)) {
+ | $hs.add$postFix($hsValue);
+ | $size++;
+ | }
+ | }
+ |}
+ |$arrayBuilder
+ |$hs = new $openHashSet$postFix($classTag);
+ |$notFoundNullElement = true;
+ |int $pos = 0;
+ |for (int $i = 0; $i < $array2.numElements(); $i++) {
--- End diff --
Why are the elements of `array2` added to the hash set again here?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]