AngersZhuuuu commented on a change in pull request #30243:
URL: https://github.com/apache/spark/pull/30243#discussion_r528062328



##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -3957,3 +3957,194 @@ case class ArrayExcept(left: Expression, right: 
Expression) extends ArrayBinaryL
 
   override def prettyName: String = "array_except"
 }
+
+/**
+ * Checks whether the array (left) contains every element of the array (right).
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array1, array2) - Returns true if the array1 contains the 
array2.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(2));
+       true
+  """,
+  group = "array_funcs",
+  since = "3.1.0")
+case class ArrayContainsArray(left: Expression, right: Expression)
+  extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with 
NullIntolerant {
+
+  /** This expression always produces a `BooleanType` value. */
+  override def dataType: DataType = BooleanType
+
+  // Overrides for the mixed-in traits: `et` is the arrays' element type and `dt`
+  // mirrors the result type.
+  override def et: DataType = elementType
+
+  override def dt: DataType = dataType
+
+  /**
+   * Defers to the parent's input checks; if they succeed, additionally requires the
+   * element type to be orderable, since elements without proper equality are compared
+   * via `ordering.equiv`.
+   */
+  override def checkInputDataTypes(): TypeCheckResult =
+    super.checkInputDataTypes() match {
+      case passed if passed.isSuccess =>
+        TypeUtils.checkForOrderingExpr(et, s"function $prettyName")
+      case failed => failed
+    }
+
+  /**
+   * Interpreted implementation: returns true when every element of `array2` also occurs
+   * in `array1`. A null in `array2` is matched by any null in `array1`, and an empty
+   * `array2` is always contained.
+   */
+  @transient lazy val evalContains: (ArrayData, ArrayData) => Boolean = {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
+      // Element type supports proper equality, so array1 is indexed with a hash set.
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          val hs = new OpenHashSet[Any]
+          var foundNullElement = false
+          var i = 0
+          while (i < array1.numElements()) {
+            // Nulls are tracked with a flag and never inserted into the hash set; this
+            // must hold for every null in array1, not only the first one.
+            if (array1.isNullAt(i)) {
+              foundNullElement = true
+            } else {
+              hs.add(array1.get(i, elementType))
+            }
+            i += 1
+          }
+          var result = true
+          i = 0
+          while (i < array2.numElements() && result) {
+            result = if (array2.isNullAt(i)) {
+              foundNullElement
+            } else {
+              hs.contains(array2.get(i, elementType))
+            }
+            i += 1
+          }
+          result
+        }
+    } else {
+      // No proper equality: nested scan comparing elements via `ordering.equiv`.
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          // Whether array1 holds a null is computed at most once (on the first null in
+          // array2) and the answer is reused for every further null in array2.
+          var nullScanDone = false
+          var array1HasNull = false
+          var allFound = true
+          var i = 0
+          while (allFound && i < array2.numElements()) {
+            if (array2.isNullAt(i)) {
+              if (!nullScanDone) {
+                var j = 0
+                while (!array1HasNull && j < array1.numElements()) {
+                  array1HasNull = array1.isNullAt(j)
+                  j += 1
+                }
+                nullScanDone = true
+              }
+              allFound = array1HasNull
+            } else {
+              val elem2 = array2.get(i, elementType)
+              var found = false
+              var j = 0
+              // The scan walks array1, so it must be bounded by array1's own length.
+              while (!found && j < array1.numElements()) {
+                found = !array1.isNullAt(j) &&
+                  ordering.equiv(elem2, array1.get(j, elementType))
+                j += 1
+              }
+              allFound = found
+            }
+            i += 1
+          }
+          allFound
+        }
+    }
+  }
+
+  /** Interpreted path: casts both inputs to `ArrayData` and delegates to `evalContains`. */
+  override def nullSafeEval(input1: Any, input2: Any): Any =
+    evalContains(input1.asInstanceOf[ArrayData], input2.asInstanceOf[ArrayData])
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    val i = ctx.freshName("i")
+    val value = ctx.freshName("value")
+    if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val result = ctx.freshName("result")
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val hashSet = ctx.freshName("hashSet")
+
+        def withArray1NullCheck(body: String): String =
+          s"""
+             |if ($array1.isNullAt($i) && !$foundNullElement) {

Review comment:
       > If there is more than one null element in `array1`, does this code 
work?
   
   Updated the UT and it works. Also changed it as below, following
https://github.com/apache/spark/pull/30243#discussion_r528061382
   ```
                |if ($array1.isNullAt($i)) {
                |  if (!$foundNullElement) {
                |    $foundNullElement = true;
                |  }
   ```

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -3957,3 +3957,194 @@ case class ArrayExcept(left: Expression, right: 
Expression) extends ArrayBinaryL
 
   override def prettyName: String = "array_except"
 }
+
+/**
+ * Checks whether the array (left) contains every element of the array (right).
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array1, array2) - Returns true if the array1 contains the 
array2.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(2));
+       true
+  """,
+  group = "array_funcs",
+  since = "3.1.0")
+case class ArrayContainsArray(left: Expression, right: Expression)
+  extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with 
NullIntolerant {
+
+  /** This expression always produces a `BooleanType` value. */
+  override def dataType: DataType = BooleanType
+
+  // Overrides for the mixed-in traits: `et` is the arrays' element type and `dt`
+  // mirrors the result type.
+  override def et: DataType = elementType
+
+  override def dt: DataType = dataType
+
+  /**
+   * Defers to the parent's input checks; if they succeed, additionally requires the
+   * element type to be orderable, since elements without proper equality are compared
+   * via `ordering.equiv`.
+   */
+  override def checkInputDataTypes(): TypeCheckResult =
+    super.checkInputDataTypes() match {
+      case passed if passed.isSuccess =>
+        TypeUtils.checkForOrderingExpr(et, s"function $prettyName")
+      case failed => failed
+    }
+
+  /**
+   * Interpreted implementation: returns true when every element of `array2` also occurs
+   * in `array1`. A null in `array2` is matched by any null in `array1`, and an empty
+   * `array2` is always contained.
+   */
+  @transient lazy val evalContains: (ArrayData, ArrayData) => Boolean = {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
+      // Element type supports proper equality, so array1 is indexed with a hash set.
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          val hs = new OpenHashSet[Any]
+          var foundNullElement = false
+          var i = 0
+          while (i < array1.numElements()) {
+            // Nulls are tracked with a flag and never inserted into the hash set; this
+            // must hold for every null in array1, not only the first one.
+            if (array1.isNullAt(i)) {
+              foundNullElement = true
+            } else {
+              hs.add(array1.get(i, elementType))
+            }
+            i += 1
+          }
+          var result = true
+          i = 0
+          while (i < array2.numElements() && result) {
+            result = if (array2.isNullAt(i)) {
+              foundNullElement
+            } else {
+              hs.contains(array2.get(i, elementType))
+            }
+            i += 1
+          }
+          result
+        }
+    } else {
+      // No proper equality: nested scan comparing elements via `ordering.equiv`.
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          // Whether array1 holds a null is computed at most once (on the first null in
+          // array2) and the answer is reused for every further null in array2.
+          var nullScanDone = false
+          var array1HasNull = false
+          var allFound = true
+          var i = 0
+          while (allFound && i < array2.numElements()) {
+            if (array2.isNullAt(i)) {
+              if (!nullScanDone) {
+                var j = 0
+                while (!array1HasNull && j < array1.numElements()) {
+                  array1HasNull = array1.isNullAt(j)
+                  j += 1
+                }
+                nullScanDone = true
+              }
+              allFound = array1HasNull
+            } else {
+              val elem2 = array2.get(i, elementType)
+              var found = false
+              var j = 0
+              // The scan walks array1, so it must be bounded by array1's own length.
+              while (!found && j < array1.numElements()) {
+                found = !array1.isNullAt(j) &&
+                  ordering.equiv(elem2, array1.get(j, elementType))
+                j += 1
+              }
+              allFound = found
+            }
+            i += 1
+          }
+          allFound
+        }
+    }
+  }
+
+  /** Interpreted path: casts both inputs to `ArrayData` and delegates to `evalContains`. */
+  override def nullSafeEval(input1: Any, input2: Any): Any =
+    evalContains(input1.asInstanceOf[ArrayData], input2.asInstanceOf[ArrayData])
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    val i = ctx.freshName("i")
+    val value = ctx.freshName("value")
+    if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val result = ctx.freshName("result")
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val hashSet = ctx.freshName("hashSet")
+
+        def withArray1NullCheck(body: String): String =
+          s"""
+             |if ($array1.isNullAt($i) && !$foundNullElement) {
+             |  $foundNullElement = true;
+             |} else {
+             |  $body
+             |}
+               """.stripMargin
+
+        val writeArray1ToHashSet = withArray1NullCheck(
+          s"""
+             |$jt $value = ${genGetValue(array1, i)};
+             |$hashSet.add$hsPostFix($hsValueCast$value);
+           """.stripMargin)
+
+        val processArray2 =
+          s"""
+             |if ($array2.isNullAt($i)) {
+             |  if (!$foundNullElement) {
+             |    $result = false;
+             |  }
+             |} else {
+             |  $jt $value = ${genGetValue(array2, i)};
+             |  if (!$hashSet.contains($hsValueCast$value)) {
+             |   $result = false;

Review comment:
       > nit: indentation
   
   Done

##########
File path: 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -3957,3 +3957,194 @@ case class ArrayExcept(left: Expression, right: 
Expression) extends ArrayBinaryL
 
   override def prettyName: String = "array_except"
 }
+
+/**
+ * Checks whether the array (left) contains every element of the array (right).
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(array1, array2) - Returns true if the array1 contains the 
array2.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_(array(1, 2, 3), array(2));
+       true
+  """,
+  group = "array_funcs",
+  since = "3.1.0")
+case class ArrayContainsArray(left: Expression, right: Expression)
+  extends BinaryArrayExpressionWithImplicitCast with ArraySetLike with 
NullIntolerant {
+
+  /** This expression always produces a `BooleanType` value. */
+  override def dataType: DataType = BooleanType
+
+  // Overrides for the mixed-in traits: `et` is the arrays' element type and `dt`
+  // mirrors the result type.
+  override def et: DataType = elementType
+
+  override def dt: DataType = dataType
+
+  /**
+   * Defers to the parent's input checks; if they succeed, additionally requires the
+   * element type to be orderable, since elements without proper equality are compared
+   * via `ordering.equiv`.
+   */
+  override def checkInputDataTypes(): TypeCheckResult =
+    super.checkInputDataTypes() match {
+      case passed if passed.isSuccess =>
+        TypeUtils.checkForOrderingExpr(et, s"function $prettyName")
+      case failed => failed
+    }
+
+  /**
+   * Interpreted implementation: returns true when every element of `array2` also occurs
+   * in `array1`. A null in `array2` is matched by any null in `array1`, and an empty
+   * `array2` is always contained.
+   */
+  @transient lazy val evalContains: (ArrayData, ArrayData) => Boolean = {
+    if (TypeUtils.typeWithProperEquals(elementType)) {
+      // Element type supports proper equality, so array1 is indexed with a hash set.
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          val hs = new OpenHashSet[Any]
+          var foundNullElement = false
+          var i = 0
+          while (i < array1.numElements()) {
+            // Nulls are tracked with a flag and never inserted into the hash set; this
+            // must hold for every null in array1, not only the first one.
+            if (array1.isNullAt(i)) {
+              foundNullElement = true
+            } else {
+              hs.add(array1.get(i, elementType))
+            }
+            i += 1
+          }
+          var result = true
+          i = 0
+          while (i < array2.numElements() && result) {
+            result = if (array2.isNullAt(i)) {
+              foundNullElement
+            } else {
+              hs.contains(array2.get(i, elementType))
+            }
+            i += 1
+          }
+          result
+        }
+    } else {
+      // No proper equality: nested scan comparing elements via `ordering.equiv`.
+      (array1, array2) =>
+        if (array2.numElements() == 0) {
+          true
+        } else if (array1.numElements() == 0) {
+          false
+        } else {
+          // Whether array1 holds a null is computed at most once (on the first null in
+          // array2) and the answer is reused for every further null in array2.
+          var nullScanDone = false
+          var array1HasNull = false
+          var allFound = true
+          var i = 0
+          while (allFound && i < array2.numElements()) {
+            if (array2.isNullAt(i)) {
+              if (!nullScanDone) {
+                var j = 0
+                while (!array1HasNull && j < array1.numElements()) {
+                  array1HasNull = array1.isNullAt(j)
+                  j += 1
+                }
+                nullScanDone = true
+              }
+              allFound = array1HasNull
+            } else {
+              val elem2 = array2.get(i, elementType)
+              var found = false
+              var j = 0
+              // The scan walks array1, so it must be bounded by array1's own length.
+              while (!found && j < array1.numElements()) {
+                found = !array1.isNullAt(j) &&
+                  ordering.equiv(elem2, array1.get(j, elementType))
+                j += 1
+              }
+              allFound = found
+            }
+            i += 1
+          }
+          allFound
+        }
+    }
+  }
+
+  /** Interpreted path: casts both inputs to `ArrayData` and delegates to `evalContains`. */
+  override def nullSafeEval(input1: Any, input2: Any): Any =
+    evalContains(input1.asInstanceOf[ArrayData], input2.asInstanceOf[ArrayData])
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): 
ExprCode = {
+    val i = ctx.freshName("i")
+    val value = ctx.freshName("value")
+    if (canUseSpecializedHashSet) {
+      val jt = CodeGenerator.javaType(elementType)
+
+      nullSafeCodeGen(ctx, ev, (array1, array2) => {
+        val result = ctx.freshName("result")
+        val foundNullElement = ctx.freshName("foundNullElement")
+        val openHashSet = classOf[OpenHashSet[_]].getName
+        val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
+        val hashSet = ctx.freshName("hashSet")
+
+        def withArray1NullCheck(body: String): String =
+          s"""
+             |if ($array1.isNullAt($i) && !$foundNullElement) {
+             |  $foundNullElement = true;
+             |} else {
+             |  $body
+             |}
+               """.stripMargin

Review comment:
       > nit: indentation
   
   Done




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to