This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dc73342db941 [SPARK-50436][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_udf
dc73342db941 is described below

commit dc73342db941a7a202acacc2a7e90ff245192712
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Dec 2 08:42:08 2024 +0900

    [SPARK-50436][PYTHON][TESTS] Use assertDataFrameEqual in pyspark.sql.tests.test_udf
    
    ### What changes were proposed in this pull request?
    Use `assertDataFrameEqual` in pyspark.sql.tests.test_udf
    
    ### Why are the changes needed?
    `assertDataFrameEqual` is explicitly built to handle DataFrame-specific comparisons, including schema.
    
    So we propose to replace `assertEqual` with `assertDataFrameEqual`.
    
    Part of https://issues.apache.org/jira/browse/SPARK-50435.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49001 from xinrong-meng/impr_test_udf.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/tests/test_udf.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/python/pyspark/sql/tests/test_udf.py b/python/pyspark/sql/tests/test_udf.py
index 78aa2546128a..819391389237 100644
--- a/python/pyspark/sql/tests/test_udf.py
+++ b/python/pyspark/sql/tests/test_udf.py
@@ -220,7 +220,7 @@ class BaseUDFTestsMixin(object):
         right = self.spark.createDataFrame([Row(a=1)])
         df = left.join(right, on="a", how="left_outer")
         df = df.withColumn("b", udf(lambda x: "x")(df.a))
-        self.assertEqual(df.filter('b = "x"').collect(), [Row(a=1, b="x")])
+        assertDataFrameEqual(df.filter('b = "x"'), [Row(a=1, b="x")])
 
     def test_udf_in_filter_on_top_of_join(self):
         # regression test for SPARK-18589
@@ -228,7 +228,7 @@ class BaseUDFTestsMixin(object):
         right = self.spark.createDataFrame([Row(b=1)])
         f = udf(lambda a, b: a == b, BooleanType())
         df = left.crossJoin(right).filter(f("a", "b"))
-        self.assertEqual(df.collect(), [Row(a=1, b=1)])
+        assertDataFrameEqual(df, [Row(a=1, b=1)])
 
     def test_udf_in_join_condition(self):
         # regression test for SPARK-25314
@@ -243,7 +243,7 @@ class BaseUDFTestsMixin(object):
                 df.collect()
         with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
             df = left.join(right, f("a", "b"))
-            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+            assertDataFrameEqual(df, [Row(a=1, b=1)])
 
     def test_udf_in_left_outer_join_condition(self):
         # regression test for SPARK-26147
@@ -256,7 +256,7 @@ class BaseUDFTestsMixin(object):
         # The Python UDF only refer to attributes from one side, so it's evaluable.
         df = left.join(right, f("a") == col("b").cast("string"), how="left_outer")
         with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
-            self.assertEqual(df.collect(), [Row(a=1, b=1)])
+            assertDataFrameEqual(df, [Row(a=1, b=1)])
 
     def test_udf_and_common_filter_in_join_condition(self):
         # regression test for SPARK-25314
@@ -266,7 +266,7 @@ class BaseUDFTestsMixin(object):
         f = udf(lambda a, b: a == b, BooleanType())
         df = left.join(right, [f("a", "b"), left.a1 == right.b1])
         # do not need spark.sql.crossJoin.enabled=true for udf is not the only join condition.
-        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+        assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
 
     def test_udf_not_supported_in_join_condition(self):
         # regression test for SPARK-25314
@@ -294,7 +294,7 @@ class BaseUDFTestsMixin(object):
         f = udf(lambda a: a, IntegerType())
 
         df = left.join(right, [f("a") == f("b"), left.a1 == right.b1])
-        self.assertEqual(df.collect(), [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
+        assertDataFrameEqual(df, [Row(a=1, a1=1, a2=1, b=1, b1=1, b2=1)])
 
     def test_udf_without_arguments(self):
         self.spark.catalog.registerFunction("foo", lambda: "bar")
@@ -331,7 +331,7 @@ class BaseUDFTestsMixin(object):
 
         my_filter = udf(lambda a: a < 2, BooleanType())
         sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
-        self.assertEqual(sel.collect(), [Row(key=1, value="1")])
+        assertDataFrameEqual(sel, [Row(key=1, value="1")])
 
     def test_udf_with_variant_input(self):
         df = self.spark.range(0, 10).selectExpr("parse_json(cast(id as string)) v")
@@ -461,7 +461,7 @@ class BaseUDFTestsMixin(object):
 
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
-        self.assertEqual(sel.collect(), [Row(key=1)])
+        assertDataFrameEqual(sel, [Row(key=1)])
 
         my_copy = udf(lambda x: x, IntegerType())
         my_add = udf(lambda a, b: int(a + b), IntegerType())
@@ -471,7 +471,7 @@ class BaseUDFTestsMixin(object):
             .agg(sum(my_strlen(col("value"))).alias("s"))
             .select(my_add(col("k"), col("s")).alias("t"))
         )
-        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+        assertDataFrameEqual(sel, [Row(t=4), Row(t=3)])
 
     def test_udf_in_generate(self):
         from pyspark.sql.functions import explode
@@ -505,7 +505,7 @@ class BaseUDFTestsMixin(object):
         my_copy = udf(lambda x: x, IntegerType())
         df = self.spark.range(10).orderBy("id")
         res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
-        self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+        assertDataFrameEqual(res, [Row(id=0, copy=0)])
 
     def test_udf_registration_returns_udf(self):
         df = self.spark.range(10)
@@ -838,12 +838,12 @@ class BaseUDFTestsMixin(object):
             for df in [filesource_df, datasource_df, datasource_v2_df]:
                 result = df.withColumn("c", c1)
                 expected = df.withColumn("c", lit(2))
-                self.assertEqual(expected.collect(), result.collect())
+                assertDataFrameEqual(expected, result)
 
             for df in [filesource_df, datasource_df, datasource_v2_df]:
                 result = df.withColumn("c", c2)
                 expected = df.withColumn("c", col("i") + 1)
-                self.assertEqual(expected.collect(), result.collect())
+                assertDataFrameEqual(expected, result)
 
             for df in [filesource_df, datasource_df, datasource_v2_df]:
                 for f in [f1, f2]:
@@ -902,7 +902,7 @@ class BaseUDFTestsMixin(object):
             result = self.spark.sql(
                 "select i from values(0L) as data(i) where i in (select id from v)"
             )
-            self.assertEqual(result.collect(), [Row(i=0)])
+            assertDataFrameEqual(result, [Row(i=0)])
 
     def test_udf_globals_not_overwritten(self):
         @udf("string")


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to