This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new be12eb732b0b [SPARK-50465][PYTHON][TESTS] Use assertDataFrameEqual in
pyspark.sql.tests.test_group and test_readwriter
be12eb732b0b is described below
commit be12eb732b0b3deedb7cf8327c61f9308cda54e0
Author: Xinrong Meng <[email protected]>
AuthorDate: Mon Dec 2 11:43:33 2024 -0800
[SPARK-50465][PYTHON][TESTS] Use assertDataFrameEqual in
pyspark.sql.tests.test_group and test_readwriter
### What changes were proposed in this pull request?
Use `assertDataFrameEqual` in pyspark.sql.tests.test_group and
test_readwriter
### Why are the changes needed?
`assertDataFrameEqual` is explicitly built to handle DataFrame-specific
comparisons, including schema.
So we propose to replace `assertEqual` with `assertDataFrameEqual`.
Part of https://issues.apache.org/jira/browse/SPARK-50435.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49023 from xinrong-meng/impr_test_group.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/tests/test_group.py | 12 ++++++------
python/pyspark/sql/tests/test_readwriter.py | 21 +++++++++++----------
2 files changed, 17 insertions(+), 16 deletions(-)
diff --git a/python/pyspark/sql/tests/test_group.py
b/python/pyspark/sql/tests/test_group.py
index 8e3d2d8d0003..bbc089b00c13 100644
--- a/python/pyspark/sql/tests/test_group.py
+++ b/python/pyspark/sql/tests/test_group.py
@@ -36,11 +36,11 @@ class GroupTestsMixin:
data = [Row(key=1, value=10), Row(key=1, value=20), Row(key=1,
value=30)]
df = self.spark.createDataFrame(data)
g = df.groupBy("key")
- self.assertEqual(g.max("value").collect(), [Row(**{"key": 1,
"max(value)": 30})])
- self.assertEqual(g.min("value").collect(), [Row(**{"key": 1,
"min(value)": 10})])
- self.assertEqual(g.sum("value").collect(), [Row(**{"key": 1,
"sum(value)": 60})])
- self.assertEqual(g.count().collect(), [Row(key=1, count=3)])
- self.assertEqual(g.mean("value").collect(), [Row(**{"key": 1,
"avg(value)": 20.0})])
+ assertDataFrameEqual(g.max("value"), [Row(**{"key": 1, "max(value)":
30})])
+ assertDataFrameEqual(g.min("value"), [Row(**{"key": 1, "min(value)":
10})])
+ assertDataFrameEqual(g.sum("value"), [Row(**{"key": 1, "sum(value)":
60})])
+ assertDataFrameEqual(g.count(), [Row(key=1, count=3)])
+ assertDataFrameEqual(g.mean("value"), [Row(**{"key": 1, "avg(value)":
20.0})])
data = [
Row(electronic="Smartphone", year=2018, sales=150000),
@@ -59,7 +59,7 @@ class GroupTestsMixin:
df = self.df
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({"key": "max", "value":
"count"}).collect()[0]))
- self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
+ assertDataFrameEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
from pyspark.sql import functions
diff --git a/python/pyspark/sql/tests/test_readwriter.py
b/python/pyspark/sql/tests/test_readwriter.py
index 2fca6b57decf..683c925eefc2 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -23,6 +23,7 @@ from pyspark.errors import AnalysisException
from pyspark.sql.functions import col, lit
from pyspark.sql.readwriter import DataFrameWriterV2
from pyspark.sql.types import StructType, StructField, StringType
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
@@ -34,15 +35,15 @@ class ReadwriterTestsMixin:
try:
df.write.json(tmpPath)
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
schema = StructType([StructField("value", StringType(), True)])
actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()),
sorted(actual.collect()))
+ assertDataFrameEqual(df.select("value"), actual)
df.write.json(tmpPath, "overwrite")
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
df.write.save(
format="json",
@@ -53,11 +54,11 @@ class ReadwriterTestsMixin:
actual = self.spark.read.load(
format="json", path=tmpPath, noUse="this options will not be
used in load."
)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
with self.sql_conf({"spark.sql.sources.default":
"org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()),
sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
csvpath = os.path.join(tempfile.mkdtemp(), "data")
df.write.option("quote", None).format("csv").save(csvpath)
@@ -71,15 +72,15 @@ class ReadwriterTestsMixin:
try:
df.write.json(tmpPath)
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
schema = StructType([StructField("value", StringType(), True)])
actual = self.spark.read.json(tmpPath, schema)
- self.assertEqual(sorted(df.select("value").collect()),
sorted(actual.collect()))
+ assertDataFrameEqual(df.select("value"), actual)
df.write.mode("overwrite").json(tmpPath)
actual = self.spark.read.json(tmpPath)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
df.write.mode("overwrite").options(
noUse="this options will not be used in save."
@@ -89,11 +90,11 @@ class ReadwriterTestsMixin:
actual = self.spark.read.format("json").load(
path=tmpPath, noUse="this options will not be used in load."
)
- self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
with self.sql_conf({"spark.sql.sources.default":
"org.apache.spark.sql.json"}):
actual = self.spark.read.load(path=tmpPath)
- self.assertEqual(sorted(df.collect()),
sorted(actual.collect()))
+ assertDataFrameEqual(df, actual)
finally:
shutil.rmtree(tmpPath)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]