Github user cloud-fan commented on a diff in the pull request:
https://github.com/apache/spark/pull/21103#discussion_r205999733
--- Diff:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
---
@@ -3968,3 +3964,267 @@ object ArrayUnion {
new GenericArrayData(arrayBuffer)
}
}
+
+/**
+ * Returns an array of the elements in array1 but not in array2, without duplicates.
+ */
+@ExpressionDescription(
+ usage = """
+ _FUNC_(array1, array2) - Returns an array of the elements in array1 but
not in array2,
+ without duplicates.
+ """,
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
+ array(2)
+ """,
+ since = "2.4.0")
+case class ArrayExcept(left: Expression, right: Expression) extends
ArraySetLike
+ with ComplexTypeMergingExpression {
+ override def dataType: DataType = {
+ dataTypeCheck
+ left.dataType
+ }
+
+ @transient lazy val evalExcept: (ArrayData, ArrayData) => ArrayData = {
+ if (elementTypeSupportEquals) {
+ (array1, array2) =>
+ val hs = new OpenHashSet[Any]
+ var notFoundNullElement = true
+ var i = 0
+ while (i < array2.numElements()) {
+ if (array2.isNullAt(i)) {
+ notFoundNullElement = false
+ } else {
+ val elem = array2.get(i, elementType)
+ hs.add(elem)
+ }
+ i += 1
+ }
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ i = 0
+ while (i < array1.numElements()) {
+ if (array1.isNullAt(i)) {
+ if (notFoundNullElement) {
+ arrayBuffer += null
+ notFoundNullElement = false
+ }
+ } else {
+ val elem = array1.get(i, elementType)
+ if (!hs.contains(elem)) {
+ arrayBuffer += elem
+ hs.add(elem)
+ }
+ }
+ i += 1
+ }
+ new GenericArrayData(arrayBuffer)
+ } else {
+ (array1, array2) =>
+ val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
+ var scannedNullElements = false
+ var i = 0
+ while (i < array1.numElements()) {
+ var found = false
+ val elem1 = array1.get(i, elementType)
+ if (elem1 == null) {
+ if (!scannedNullElements) {
+ var j = 0
+ while (!found && j < array2.numElements()) {
+ found = array2.isNullAt(j)
+ j += 1
+ }
+ // array2 is scanned only once for null element
+ scannedNullElements = true
+ } else {
+ found = true
+ }
+ } else {
+ var j = 0
+ while (!found && j < array2.numElements()) {
+ val elem2 = array2.get(j, elementType)
+ if (elem2 != null) {
+ found = ordering.equiv(elem1, elem2)
+ }
+ j += 1
+ }
+ if (!found) {
+ // check whether elem1 is already stored in arrayBuffer
+ var k = 0
+ while (!found && k < arrayBuffer.size) {
+ val va = arrayBuffer(k)
+ found = (va != null) && ordering.equiv(va, elem1)
+ k += 1
+ }
+ }
+ }
+ if (!found) {
+ arrayBuffer += elem1
+ }
+ i += 1
+ }
+ new GenericArrayData(arrayBuffer)
+ }
+ }
+
+ override def nullSafeEval(input1: Any, input2: Any): Any = {
+ val array1 = input1.asInstanceOf[ArrayData]
+ val array2 = input2.asInstanceOf[ArrayData]
+
+ evalExcept(array1, array2)
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val arrayData = classOf[ArrayData].getName
+ val i = ctx.freshName("i")
+ val pos = ctx.freshName("pos")
+ val value = ctx.freshName("value")
+ val hsValue = ctx.freshName("hsValue")
+ val size = ctx.freshName("size")
+ val (postFix, openHashElementType, hsJavaTypeName, genHsValue,
+ getter, setter, javaTypeName, arrayBuilder) =
+ if (elementTypeSupportEquals) {
+ elementType match {
+ case BooleanType | ByteType | ShortType | IntegerType =>
+ val ptName = CodeGenerator.primitiveTypeName(elementType)
+ val unsafeArray = ctx.freshName("unsafeArray")
+ ("$mcI$sp", "Int", "int",
+ if (elementType != BooleanType) {
+ s"(int) $value"
+ } else {
+ s"$value ? 1 : 0;"
+ },
+ s"get$ptName($i)", s"set$ptName($pos, $value)",
CodeGenerator.javaType(elementType),
+ s"""
+ |${ctx.createUnsafeArray(unsafeArray, size, elementType,
s" $prettyName failed.")}
+ |${ev.value} = $unsafeArray;
+ """.stripMargin)
+ case LongType | FloatType | DoubleType =>
+ val ptName = CodeGenerator.primitiveTypeName(elementType)
+ val unsafeArray = ctx.freshName("unsafeArray")
+ val signature = elementType match {
+ case LongType => "$mcJ$sp"
+ case FloatType => "$mcF$sp"
+ case DoubleType => "$mcD$sp"
+ }
+ (signature, CodeGenerator.boxedType(elementType),
+ CodeGenerator.javaType(elementType), value,
+ s"get$ptName($i)", s"set$ptName($pos, $value)",
CodeGenerator.javaType(elementType),
+ s"""
+ |${ctx.createUnsafeArray(unsafeArray, size, elementType,
s" $prettyName failed.")}
+ |${ev.value} = $unsafeArray;
+ """.stripMargin)
+ case _ =>
+ val genericArrayData = classOf[GenericArrayData].getName
+ val et = ctx.addReferenceObj("elementType", elementType)
+ ("", "Object", "Object", value,
+ s"get($i, $et)", s"update($pos, $value)", "Object",
+ s"${ev.value} = new $genericArrayData(new Object[$size]);")
+ }
+ } else {
+ ("", "", "", "", "", "", "", "")
+ }
+
+ nullSafeCodeGen(ctx, ev, (array1, array2) => {
+ if (openHashElementType != "") {
--- End diff --
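A better way to organize it is to branch on `elementTypeSupportEquals` once and call `nullSafeCodeGen` in each branch, rather than building a tuple of empty-string placeholders and then re-checking `openHashElementType != ""`: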
```
if (elementTypeSupportEquals) {
...
nullSafeCodeGen(...)
} else {
nullSafeCodeGen(...)
}
```
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]