This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new be12eb732b0b [SPARK-50465][PYTHON][TESTS] Use assertDataFrameEqual in 
pyspark.sql.tests.test_group and test_readwriter
be12eb732b0b is described below

commit be12eb732b0b3deedb7cf8327c61f9308cda54e0
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Dec 2 11:43:33 2024 -0800

    [SPARK-50465][PYTHON][TESTS] Use assertDataFrameEqual in 
pyspark.sql.tests.test_group and test_readwriter
    
    ### What changes were proposed in this pull request?
    Use `assertDataFrameEqual` in pyspark.sql.tests.test_group and 
test_readwriter
    
    ### Why are the changes needed?
    `assertDataFrameEqual` is explicitly built to handle DataFrame-specific 
comparisons, including schema.
    
    So we propose to replace `assertEqual` with `assertDataFrameEqual`.
    
    Part of https://issues.apache.org/jira/browse/SPARK-50435.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #49023 from xinrong-meng/impr_test_group.
    
    Authored-by: Xinrong Meng <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/tests/test_group.py      | 12 ++++++------
 python/pyspark/sql/tests/test_readwriter.py | 21 +++++++++++----------
 2 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/python/pyspark/sql/tests/test_group.py 
b/python/pyspark/sql/tests/test_group.py
index 8e3d2d8d0003..bbc089b00c13 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -36,11 +36,11 @@ class GroupTestsMixin:
         data = [Row(key=1, value=10), Row(key=1, value=20), Row(key=1, 
value=30)]
         df = self.spark.createDataFrame(data)
         g = df.groupBy("key")
-        self.assertEqual(g.max("value").collect(), [Row(**{"key": 1, 
"max(value)": 30})])
-        self.assertEqual(g.min("value").collect(), [Row(**{"key": 1, 
"min(value)": 10})])
-        self.assertEqual(g.sum("value").collect(), [Row(**{"key": 1, 
"sum(value)": 60})])
-        self.assertEqual(g.count().collect(), [Row(key=1, count=3)])
-        self.assertEqual(g.mean("value").collect(), [Row(**{"key": 1, 
"avg(value)": 20.0})])
+        assertDataFrameEqual(g.max("value"), [Row(**{"key": 1, "max(value)": 
30})])
+        assertDataFrameEqual(g.min("value"), [Row(**{"key": 1, "min(value)": 
10})])
+        assertDataFrameEqual(g.sum("value"), [Row(**{"key": 1, "sum(value)": 
60})])
+        assertDataFrameEqual(g.count(), [Row(key=1, count=3)])
+        assertDataFrameEqual(g.mean("value"), [Row(**{"key": 1, "avg(value)": 
20.0})])
 
         data = [
             Row(electronic="Smartphone", year=2018, sales=150000),
@@ -59,7 +59,7 @@ class GroupTestsMixin:
         df = self.df
         g = df.groupBy()
         self.assertEqual([99, 100], sorted(g.agg({"key": "max", "value": 
"count"}).collect()[0]))
-        self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+        assertDataFrameEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
 
         from pyspark.sql import functions
 
diff --git a/python/pyspark/sql/tests/test_readwriter.py 
b/python/pyspark/sql/tests/test_readwriter.py
index 2fca6b57decf..683c925eefc2 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -23,6 +23,7 @@ from pyspark.errors import AnalysisException
 from pyspark.sql.functions import col, lit
 from pyspark.sql.readwriter import DataFrameWriterV2
 from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.testing import assertDataFrameEqual
 from pyspark.testing.sqlutils import ReusedSQLTestCase
 
 
@@ -34,15 +35,15 @@ class ReadwriterTestsMixin:
         try:
             df.write.json(tmpPath)
             actual = self.spark.read.json(tmpPath)
-            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            assertDataFrameEqual(df, actual)
 
             schema = StructType([StructField("value", StringType(), True)])
             actual = self.spark.read.json(tmpPath, schema)
-            self.assertEqual(sorted(df.select("value").collect()), 
sorted(actual.collect()))
+            assertDataFrameEqual(df.select("value"), actual)
 
             df.write.json(tmpPath, "overwrite")
             actual = self.spark.read.json(tmpPath)
-            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            assertDataFrameEqual(df, actual)
 
             df.write.save(
                 format="json",
@@ -53,11 +54,11 @@ class ReadwriterTestsMixin:
             actual = self.spark.read.load(
                 format="json", path=tmpPath, noUse="this options will not be 
used in load."
             )
-            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            assertDataFrameEqual(df, actual)
 
             with self.sql_conf({"spark.sql.sources.default": 
"org.apache.spark.sql.json"}):
                 actual = self.spark.read.load(path=tmpPath)
-                self.assertEqual(sorted(df.collect()), 
sorted(actual.collect()))
+                assertDataFrameEqual(df, actual)
 
             csvpath = os.path.join(tempfile.mkdtemp(), "data")
             df.write.option("quote", None).format("csv").save(csvpath)
@@ -71,15 +72,15 @@ class ReadwriterTestsMixin:
         try:
             df.write.json(tmpPath)
             actual = self.spark.read.json(tmpPath)
-            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            assertDataFrameEqual(df, actual)
 
             schema = StructType([StructField("value", StringType(), True)])
             actual = self.spark.read.json(tmpPath, schema)
-            self.assertEqual(sorted(df.select("value").collect()), 
sorted(actual.collect()))
+            assertDataFrameEqual(df.select("value"), actual)
 
             df.write.mode("overwrite").json(tmpPath)
             actual = self.spark.read.json(tmpPath)
-            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            assertDataFrameEqual(df, actual)
 
             df.write.mode("overwrite").options(
                 noUse="this options will not be used in save."
@@ -89,11 +90,11 @@ class ReadwriterTestsMixin:
             actual = self.spark.read.format("json").load(
                 path=tmpPath, noUse="this options will not be used in load."
             )
-            self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+            assertDataFrameEqual(df, actual)
 
             with self.sql_conf({"spark.sql.sources.default": 
"org.apache.spark.sql.json"}):
                 actual = self.spark.read.load(path=tmpPath)
-                self.assertEqual(sorted(df.collect()), 
sorted(actual.collect()))
+                assertDataFrameEqual(df, actual)
         finally:
             shutil.rmtree(tmpPath)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to