Github user ueshin commented on a diff in the pull request:

    https://github.com/apache/spark/pull/21102#discussion_r207766511
  
    --- Diff: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
 ---
    @@ -3965,6 +4034,242 @@ object ArrayUnion {
       }
     }
     
    +/**
    + * Returns an array of the elements in the intersect of x and y, without 
duplicates
    + */
    +@ExpressionDescription(
    +  usage = """
    +  _FUNC_(array1, array2) - Returns an array of the elements in the 
intersection of array1 and
    +    array2, without duplicates.
    +  """,
    +  examples = """
    +    Examples:Fun
    +      > SELECT _FUNC_(array(1, 2, 3), array(1, 3, 5));
    +       array(1, 3)
    +  """,
    +  since = "2.4.0")
    +case class ArrayIntersect(left: Expression, right: Expression) extends 
ArraySetLike
    +  with ComplexTypeMergingExpression {
    +  override def dataType: DataType = {
    +    dataTypeCheck
    +    ArrayType(elementType,
    +      left.dataType.asInstanceOf[ArrayType].containsNull &&
    +        right.dataType.asInstanceOf[ArrayType].containsNull)
    +  }
    +
    +  @transient lazy val evalIntersect: (ArrayData, ArrayData) => ArrayData = 
{
    +    if (elementTypeSupportEquals) {
    +      (array1, array2) =>
    +        val hs = new OpenHashSet[Any]
    +        val hsResult = new OpenHashSet[Any]
    +        var foundNullElement = false
    +        var i = 0
    +        while (i < array2.numElements()) {
    +          if (array2.isNullAt(i)) {
    +            foundNullElement = true
    +          } else {
    +            val elem = array2.get(i, elementType)
    +            hs.add(elem)
    +          }
    +          i += 1
    +        }
    +        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +        i = 0
    +        while (i < array1.numElements()) {
    +          if (array1.isNullAt(i)) {
    +            if (foundNullElement) {
    +              arrayBuffer += null
    +              foundNullElement = false
    +            }
    +          } else {
    +            val elem = array1.get(i, elementType)
    +            if (hs.contains(elem) && !hsResult.contains(elem)) {
    +              arrayBuffer += elem
    +              hsResult.add(elem)
    +            }
    +          }
    +          i += 1
    +        }
    +        new GenericArrayData(arrayBuffer)
    +    } else {
    +      (array1, array2) =>
    +        val arrayBuffer = new scala.collection.mutable.ArrayBuffer[Any]
    +        var alreadySeenNull = false
    +        var i = 0
    +        while (i < array1.numElements()) {
    +          var found = false
    +          val elem1 = array1.get(i, elementType)
    +          if (array1.isNullAt(i)) {
    +            if (!alreadySeenNull) {
    +              var j = 0
    +              while (!found && j < array2.numElements()) {
    +                found = array2.isNullAt(j)
    +                j += 1
    +              }
    +              // array2 is scanned only once for null element
    +              alreadySeenNull = true
    +            }
    +          } else {
    +            var j = 0
    +            while (!found && j < array2.numElements()) {
    +              if (!array2.isNullAt(j)) {
    +                val elem2 = array2.get(j, elementType)
    +                if (ordering.equiv(elem1, elem2)) {
    +                  // check whether elem1 is already stored in arrayBuffer
    +                  var foundArrayBuffer = false
    +                  var k = 0
    +                  while (!foundArrayBuffer && k < arrayBuffer.size) {
    +                    val va = arrayBuffer(k)
    +                    foundArrayBuffer = (va != null) && ordering.equiv(va, 
elem1)
    +                    k += 1
    +                  }
    +                  found = !foundArrayBuffer
    +                }
    +              }
    +              j += 1
    +            }
    +          }
    +          if (found) {
    +            arrayBuffer += elem1
    +          }
    +          i += 1
    +        }
    +        new GenericArrayData(arrayBuffer)
    +    }
    +  }
    +
    +  override def nullSafeEval(input1: Any, input2: Any): Any = {
    +    val array1 = input1.asInstanceOf[ArrayData]
    +    val array2 = input2.asInstanceOf[ArrayData]
    +
    +    evalIntersect(array1, array2)
    +  }
    +
    +  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    +    val arrayData = classOf[ArrayData].getName
    +    val i = ctx.freshName("i")
    +    val value = ctx.freshName("value")
    +    val size = ctx.freshName("size")
    +    if (canUseSpecializedHashSet) {
    +      val jt = CodeGenerator.javaType(elementType)
    +      val ptName = CodeGenerator.primitiveTypeName(jt)
    +
    +      nullSafeCodeGen(ctx, ev, (array1, array2) => {
    +        val foundNullElement = ctx.freshName("foundNullElement")
    +        val nullElementIndex = ctx.freshName("nullElementIndex")
    +        val builder = ctx.freshName("builder")
    +        val openHashSet = classOf[OpenHashSet[_]].getName
    +        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
    +        val hashSet = ctx.freshName("hashSet")
    +        val hashSetResult = ctx.freshName("hashSetResult")
    +        val arrayBuilder = "scala.collection.mutable.ArrayBuilder"
    +        val arrayBuilderClass = s"$arrayBuilder$$of$ptName"
    +        val arrayBuilderClassTag = 
s"scala.reflect.ClassTag$$.MODULE$$.$ptName()"
    +
    +        def withArray2NullCheck(body: String): String =
    +          if (right.dataType.asInstanceOf[ArrayType].containsNull) {
    +            if (left.dataType.asInstanceOf[ArrayType].containsNull) {
    +              s"""
    +                 |if ($array2.isNullAt($i)) {
    +                 |  $foundNullElement = true;
    +                 |} else {
    +                 |  $body
    +                 |}
    +               """.stripMargin
    +            } else {
    +              // if array1's element is not nullable, we don't need to 
track the null element index.
    +              s"""
    +                 |if (!$array2.isNullAt($i)) {
    +                 |  $body
    +                 |}
    +               """.stripMargin
    +            }
    +          } else {
    +            body
    +          }
    +
    +        val writeArray2ToHashSet = withArray2NullCheck(
    +          s"""
    +             |$jt $value = ${genGetValue(array2, i)};
    +             |$hashSet.add$hsPostFix($hsValueCast$value);
    +           """.stripMargin)
    +
    +        def withArray1NullAssignment(body: String) =
    +          if (left.dataType.asInstanceOf[ArrayType].containsNull) {
    +            if (right.dataType.asInstanceOf[ArrayType].containsNull) {
    +              s"""
    +                 |if ($array1.isNullAt($i)) {
    +                 |  if ($foundNullElement) {
    +                 |    $nullElementIndex = $size;
    +                 |    $foundNullElement = false;
    +                 |    $size++;
    +                 |    $builder.$$plus$$eq($nullValueHolder);
    +                 |  }
    +                 |} else {
    +                 |  $body
    +                 |}
    +               """.stripMargin
    +            } else {
    +              s"""
    +                 |if (!$array1.isNullAt($i)) {
    +                 |  $body
    +                 |}
    +               """.stripMargin
    +            }
    +          } else {
    +            body
    +          }
    +
    +        val processArray1 = withArray1NullAssignment(
    +          s"""
    +             |$jt $value = ${genGetValue(array1, i)};
    +             |if ($hashSet.contains($hsValueCast$value) &&
    +             |    !$hashSetResult.contains($hsValueCast$value)) {
    +             |  if (++$size > 
${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
    +             |    break;
    +             |  }
    +             |  $hashSetResult.add$hsPostFix($hsValueCast$value);
    +             |  $builder.$$plus$$eq($value);
    +             |}
    +           """.stripMargin)
    +
    +        // Only need to track null element index when result array's 
element is nullable.
    +        val declareNullTrackVariables = if 
(dataType.asInstanceOf[ArrayType].containsNull) {
    +          s"""
    +             |boolean $foundNullElement = false;
    +             |int $nullElementIndex = -1;
    +           """.stripMargin
    +        } else {
    +          ""
    +        }
    +
    +        s"""
    +           |$openHashSet $hashSet = new $openHashSet$hsPostFix($classTag);
    +           |$openHashSet $hashSetResult = new 
$openHashSet$hsPostFix($classTag);
    +           |$declareNullTrackVariables
    +           |for (int $i = 0; $i < $array2.numElements(); $i++) {
    +           |  $writeArray2ToHashSet
    +           |}
    +           |$arrayBuilderClass $builder =
    +           |  
($arrayBuilderClass)$arrayBuilder.make($arrayBuilderClassTag);
    --- End diff --
    
    nit: wouldn't `new $arrayBuilderClass()` work here, instead of casting the result of `$arrayBuilder.make($arrayBuilderClassTag)`?


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to