kiszk commented on a change in pull request #30243:
URL: https://github.com/apache/spark/pull/30243#discussion_r528060697
##########
File path:
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -3957,3 +3957,194 @@ case class ArrayExcept(left: Expression, right:
Expression) extends ArrayBinaryL
override def prettyName: String = "array_except"
}
+
+/**
+ * Checks whether the array (left) contains every element of the array (right).
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(array1, array2) - Returns true if the array1 contains the
array2.",
+ examples = """
+ Examples:
+ > SELECT _FUNC_(array(1, 2, 3), array(2));
+ true
+ """,
+ group = "array_funcs",
+ since = "3.1.0")
+case class ArrayContainsArray(left: Expression, right: Expression)
+ extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with
NullIntolerant {
+
+  // ArraySetLike hook: element type of the input arrays (presumably the `elementType`
+  // inherited from BinaryArrayExpressionWithImplicitCast).
+  override def et: DataType = elementType
+
+  // ArraySetLike hook: mirrors the result type, which here is a Boolean, not an array.
+  override def dt: DataType = dataType
+
+  /** The expression yields whether `left` contains every element of `right`. */
+  override def dataType: DataType = BooleanType
+
+  /**
+   * Runs the parent's input-type checks first; only when those pass does it additionally
+   * require the element type to be orderable (the fallback comparison uses `ordering`).
+   */
+  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
+    case passed if passed.isSuccess =>
+      TypeUtils.checkForOrderingExpr(et, s"function $prettyName")
+    case failed =>
+      failed
+  }
+
+  /**
+   * Interpreted evaluation: returns true iff every element of `array2` also occurs in
+   * `array1`. A null in `array2` is matched only by a null in `array1`, and an empty
+   * `array2` is always contained (even in an empty `array1`).
+   *
+   * Element types with proper `equals` are answered with a hash set built from `array1`;
+   * all other element types fall back to a nested scan that compares with `ordering`.
+   */
+  @transient lazy val evalContains: (ArrayData, ArrayData) => Boolean = {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          val hs = new OpenHashSet[Any]
+          var array1HasNull = false
+          var i = 0
+          while (i < array1.numElements()) {
+            // Every null (not just the first) is recorded in the flag instead of the set,
+            // so a null key never reaches OpenHashSet (which presumably hashes its keys).
+            if (array1.isNullAt(i)) {
+              array1HasNull = true
+            } else {
+              hs.add(array1.get(i, elementType))
+            }
+            i += 1
+          }
+          var result = true
+          i = 0
+          while (result && i < array2.numElements()) {
+            result = if (array2.isNullAt(i)) {
+              array1HasNull
+            } else {
+              hs.contains(array2.get(i, elementType))
+            }
+            i += 1
+          }
+          result
+        }
+    } else {
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          // array1 is scanned for nulls at most once; the answer is cached so every later
+          // null in array2 reuses it instead of being treated as "not found".
+          var nullScanDone = false
+          var array1HasNull = false
+          var result = true
+          var i = 0
+          while (result && i < array2.numElements()) {
+            if (array2.isNullAt(i)) {
+              if (!nullScanDone) {
+                var j = 0
+                while (!array1HasNull && j < array1.numElements()) {
+                  array1HasNull = array1.isNullAt(j)
+                  j += 1
+                }
+                nullScanDone = true
+              }
+              result = array1HasNull
+            } else {
+              // The inner scan indexes array1, so it must be bounded by array1's length.
+              val elem2 = array2.get(i, elementType)
+              var found = false
+              var j = 0
+              while (!found && j < array1.numElements()) {
+                found = !array1.isNullAt(j) &&
+                  ordering.equiv(elem2, array1.get(j, elementType))
+                j += 1
+              }
+              result = found
+            }
+            i += 1
+          }
+          result
+        }
+    }
+  }
+
+  /** Interpreted path: casts both inputs to ArrayData and delegates to `evalContains`. */
+  override def nullSafeEval(input1: Any, input2: Any): Any =
+    evalContains(input1.asInstanceOf[ArrayData], input2.asInstanceOf[ArrayData])
+
+ override protected def doGenCode(ctx: CodegenContext, ev: ExprCode):
ExprCode = {
+ val i = ctx.freshName("i")
+ val value = ctx.freshName("value")
+ if (canUseSpecializedHashSet) {
+ val jt = CodeGenerator.javaType(elementType)
+
+ nullSafeCodeGen(ctx, ev, (array1, array2) => {
+ val result = ctx.freshName("result")
+ val foundNullElement = ctx.freshName("foundNullElement")
+ val openHashSet = classOf[OpenHashSet[_]].getName
+ val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+ val hashSet = ctx.freshName("hashSet")
+
+ def withArray1NullCheck(body: String): String =
+ s"""
+ |if ($array1.isNullAt($i) && !$foundNullElement) {
+ | $foundNullElement = true;
+ |} else {
+ | $body
+ |}
+ """.stripMargin
+
+ val writeArray1ToHashSet = withArray1NullCheck(
+ s"""
+ |$jt $value = ${genGetValue(array1, i)};
+ |$hashSet.add$hsPostFix($hsValueCast$value);
+ """.stripMargin)
+
+ val processArray2 =
+ s"""
+ |if ($array2.isNullAt($i)) {
+ | if (!$foundNullElement) {
+ | $result = false;
+ | }
+ |} else {
+ | $jt $value = ${genGetValue(array2, i)};
+ | if (!$hashSet.contains($hsValueCast$value)) {
+ | $result = false;
Review comment:
nit: indentation
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]