zhengruifeng commented on code in PR #38961:
URL: https://github.com/apache/spark/pull/38961#discussion_r1041847513


##########
python/pyspark/sql/tests/connect/test_connect_function.py:
##########
@@ -413,6 +431,144 @@ def test_aggregation_functions(self):
             sdf.groupBy("a").agg(SF.percentile_approx(sdf.b, [0.1, 0.9])).toPandas(),
         )
 
+    def test_collection_functions(self):
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT * FROM VALUES
+            (ARRAY('a', 'ab'), ARRAY(1, 2, 3), ARRAY(1, NULL, 3), 1, 2, 'a'),
+            (ARRAY('x', NULL), NULL, ARRAY(1, 3), 3, 4, 'x'),
+            (NULL, ARRAY(-1, -2, -3), Array(), 5, 6, NULL)
+            AS tab(a, b, c, d, e, f)
+            """
+        # +---------+------------+------------+---+---+----+
+        # |        a|           b|           c|  d|  e|   f|
+        # +---------+------------+------------+---+---+----+
+        # |  [a, ab]|   [1, 2, 3]|[1, null, 3]|  1|  2|   a|
+        # |[x, null]|        null|      [1, 3]|  3|  4|   x|
+        # |     null|[-1, -2, -3]|          []|  5|  6|null|
+        # +---------+------------+------------+---+---+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        for cfunc, sfunc in [
+            (CF.array_distinct, SF.array_distinct),
+            (CF.array_max, SF.array_max),
+            (CF.array_min, SF.array_min),
+        ]:
+            self.assert_eq(
+                cdf.select(cfunc("a"), cfunc(cdf.b)).toPandas(),
+                sdf.select(sfunc("a"), sfunc(sdf.b)).toPandas(),
+            )
+
+        for cfunc, sfunc in [
+            (CF.array_except, SF.array_except),
+            (CF.array_intersect, SF.array_intersect),
+            (CF.array_union, SF.array_union),
+            (CF.arrays_overlap, SF.arrays_overlap),
+        ]:
+            self.assert_eq(
+                cdf.select(cfunc("b", cdf.c)).toPandas(),
+                sdf.select(sfunc("b", sdf.c)).toPandas(),
+            )
+
+        for cfunc, sfunc in [
+            (CF.array_position, SF.array_position),
+            (CF.array_remove, SF.array_remove),
+        ]:
+            self.assert_eq(
+                cdf.select(cfunc(cdf.a, "ab")).toPandas(),
+                sdf.select(sfunc(sdf.a, "ab")).toPandas(),
+            )
+
+        # test array
+        self.assert_eq(
+            cdf.select(CF.array(cdf.d, "e")).toPandas(),
+            sdf.select(SF.array(sdf.d, "e")).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.array(cdf.d, "e", CF.lit(99))).toPandas(),
+            sdf.select(SF.array(sdf.d, "e", SF.lit(99))).toPandas(),
+        )
+
+        # test array_contains
+        self.assert_eq(
+            cdf.select(CF.array_contains(cdf.a, "ab")).toPandas(),
+            sdf.select(SF.array_contains(sdf.a, "ab")).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.array_contains(cdf.a, cdf.f)).toPandas(),
+            sdf.select(SF.array_contains(sdf.a, sdf.f)).toPandas(),
+        )
+
+        # test array_join
+        self.assert_eq(
+            cdf.select(
+                CF.array_join(cdf.a, ","), CF.array_join("b", ":"), CF.array_join("c", "~")
+            ).toPandas(),
+            sdf.select(
+                SF.array_join(sdf.a, ","), SF.array_join("b", ":"), SF.array_join("c", "~")
+            ).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(
+                CF.array_join(cdf.a, ",", "_null_"),
+                CF.array_join("b", ":", ".null."),
+                CF.array_join("c", "~", "NULL"),
+            ).toPandas(),
+            sdf.select(
+                SF.array_join(sdf.a, ",", "_null_"),
+                SF.array_join("b", ":", ".null."),
+                SF.array_join("c", "~", "NULL"),
+            ).toPandas(),
+        )
+
+        # test array_repeat
+        self.assert_eq(
+            cdf.select(CF.array_repeat(cdf.f, "d")).toPandas(),
+            sdf.select(SF.array_repeat(sdf.f, "d")).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(CF.array_repeat("f", cdf.d)).toPandas(),
+            sdf.select(SF.array_repeat("f", sdf.d)).toPandas(),
+        )
+        # TODO: Make Literal contains DataType
+        #   Cannot resolve "array_repeat(f, 3)" due to data type mismatch:

Review Comment:
   I think we may need to revisit https://github.com/apache/spark/pull/38800#discussion_r1033281998 at some point; we'd better specify the `DataType` of `3` as `IntegerType` instead of `LongType`.
   
   Or we can use `Cast` as a workaround.
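   
   For illustration only, a minimal sketch of the `Cast` workaround, reusing `cdf` from the test above and assuming the connect `Column` exposes `.cast(...)` like the regular PySpark column (the cast-to-"int" call is an assumption, not code from this PR):
   
       from pyspark.sql.connect import functions as CF
   
       # Force the repeat-count literal to IntegerType instead of the default
       # LongType, so that "array_repeat(f, 3)" resolves without a type mismatch.
       cdf.select(CF.array_repeat("f", CF.lit(3).cast("int"))).toPandas()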



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

