Github user ueshin commented on a diff in the pull request:
https://github.com/apache/spark/pull/21069#discussion_r191958470
--- Diff:
sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
---
@@ -552,4 +552,60 @@ class CollectionExpressionsSuite extends SparkFunSuite
with ExpressionEvalHelper
checkEvaluation(ArrayRepeat(strArray, Literal(2)), Seq(Seq("hi",
"hola"), Seq("hi", "hola")))
checkEvaluation(ArrayRepeat(Literal("hi"), Literal(null,
IntegerType)), null)
}
+
+ test("Array remove") {
+ val a0 = Literal.create(Seq(1, 2, 3, 2, 2, 5), ArrayType(IntegerType))
+ val a1 = Literal.create(Seq("b", "a", "a", "c", "b"),
ArrayType(StringType))
+ val a2 = Literal.create(Seq[String](null, "", null, ""),
ArrayType(StringType))
+ val a3 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+ val a4 = Literal.create(null, ArrayType(StringType))
+ val a5 = Literal.create(Seq(1, null, 8, 9, null),
ArrayType(IntegerType))
+ val a6 = Literal.create(Seq(true, false, false, true),
ArrayType(BooleanType))
+
+ checkEvaluation(ArrayRemove(a0, Literal(0)), Seq(1, 2, 3, 2, 2, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(1)), Seq(2, 3, 2, 2, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(2)), Seq(1, 3, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(3)), Seq(1, 2, 2, 2, 5))
+ checkEvaluation(ArrayRemove(a0, Literal(5)), Seq(1, 2, 3, 2, 2))
+ checkEvaluation(ArrayRemove(a0, Literal(null, IntegerType)), null)
+
+ checkEvaluation(ArrayRemove(a1, Literal("")), Seq("b", "a", "a", "c",
"b"))
+ checkEvaluation(ArrayRemove(a1, Literal("a")), Seq("b", "c", "b"))
+ checkEvaluation(ArrayRemove(a1, Literal("b")), Seq("a", "a", "c"))
+ checkEvaluation(ArrayRemove(a1, Literal("c")), Seq("b", "a", "a", "b"))
+
+ checkEvaluation(ArrayRemove(a2, Literal("")), Seq(null, null))
+ checkEvaluation(ArrayRemove(a2, Literal(null, StringType)), null)
+
+ checkEvaluation(ArrayRemove(a3, Literal(1)), Seq.empty[Integer])
+
+ checkEvaluation(ArrayRemove(a4, Literal("a")), null)
+
+ checkEvaluation(ArrayRemove(a5, Literal(9)), Seq(1, null, 8, null))
+ checkEvaluation(ArrayRemove(a6, Literal(false)), Seq(true, true))
+
+ // complex data types
+ val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6),
Array[Byte](1, 2),
+ Array[Byte](1, 2), Array[Byte](5, 6)), ArrayType(BinaryType))
+ val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
+ ArrayType(BinaryType))
+ val b2 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
+ ArrayType(BinaryType))
+ val nullBinary = Literal.create(null, BinaryType)
+
+ val dataToRemoved1 = Literal.create(Array[Byte](5, 6), BinaryType)
+ checkEvaluation(ArrayRemove(b0, dataToRemoved1),
+ Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](1, 2)))
+ checkEvaluation(ArrayRemove(b0, nullBinary), null)
+ checkEvaluation(ArrayRemove(b1, dataToRemoved1),
Seq[Array[Byte]](Array[Byte](2, 1), null))
+ checkEvaluation(ArrayRemove(b2, dataToRemoved1),
Seq[Array[Byte]](null, Array[Byte](1, 2)))
+
+ val c0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
+ ArrayType(ArrayType(IntegerType)))
+ val c1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
+ ArrayType(ArrayType(IntegerType)))
--- End diff --
What about a nested array that contains a null element, e.g. `val c2 = Literal.create(Seq[Seq[Int]](null, Seq[Int](2, 1)),
ArrayType(ArrayType(IntegerType)))`? Could you add a test case covering it?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]