This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6bdae9d0ca7a [SPARK-55666][PYTHON][TEST] Fix flaky connect tests due
to non-deterministic row order
6bdae9d0ca7a is described below
commit 6bdae9d0ca7a94f34101a55dc2f7caf0b3f420d7
Author: Yicong-Huang <[email protected]>
AuthorDate: Wed Feb 25 13:38:03 2026 -0800
[SPARK-55666][PYTHON][TEST] Fix flaky connect tests due to
non-deterministic row order
### What changes were proposed in this pull request?
Use `assertDataFrameEqual` to avoid non-deterministic row order failures.
### Why are the changes needed?
Many tests use `assertEqual(cdf.collect(), sdf.collect())` to compare
results with bag semantics. Since bags do not guarantee row order, this can fail
non-deterministically (flaky test).
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Existing test
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54460 from
Yicong-Huang/SPARK-55666/fix/test-join-ambiguous-cols-row-order.
Authored-by: Yicong-Huang <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
.../sql/tests/connect/test_connect_basic.py | 21 ++++++++-------------
.../sql/tests/connect/test_connect_collection.py | 11 +++++------
.../sql/tests/connect/test_connect_column.py | 4 ++--
.../connect/test_connect_dataframe_property.py | 22 ++++++++--------------
.../sql/tests/connect/test_connect_function.py | 2 +-
5 files changed, 24 insertions(+), 36 deletions(-)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 5e2a5bf97796..6ef0b36c5208 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -36,6 +36,7 @@ from pyspark.sql.types import (
ArrayType,
Row,
)
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.utils import eventually
from pyspark.testing.connectutils import (
should_test_connect,
@@ -320,14 +321,12 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
cdf3 = cdf1.join(cdf2, cdf1["value"] == cdf2["value"])
sdf3 = sdf1.join(sdf2, sdf1["value"] == sdf2["value"])
- self.assertEqual(cdf3.schema, sdf3.schema)
- self.assertEqual(cdf3.collect(), sdf3.collect())
+ assertDataFrameEqual(cdf3, sdf3)
cdf4 = cdf1.join(cdf2, cdf1["value"].eqNullSafe(cdf2["value"]))
sdf4 = sdf1.join(sdf2, sdf1["value"].eqNullSafe(sdf2["value"]))
- self.assertEqual(cdf4.schema, sdf4.schema)
- self.assertEqual(cdf4.collect(), sdf4.collect())
+ assertDataFrameEqual(cdf4, sdf4)
cdf5 = cdf1.join(
cdf2, (cdf1["value"] == cdf2["value"]) &
(cdf1["value"].eqNullSafe(cdf2["value"]))
@@ -336,20 +335,17 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
sdf2, (sdf1["value"] == sdf2["value"]) &
(sdf1["value"].eqNullSafe(sdf2["value"]))
)
- self.assertEqual(cdf5.schema, sdf5.schema)
- self.assertEqual(cdf5.collect(), sdf5.collect())
+ assertDataFrameEqual(cdf5, sdf5)
cdf6 = cdf1.join(cdf2, cdf1["value"] ==
cdf2["value"]).select(cdf1.value)
sdf6 = sdf1.join(sdf2, sdf1["value"] ==
sdf2["value"]).select(sdf1.value)
- self.assertEqual(cdf6.schema, sdf6.schema)
- self.assertEqual(cdf6.collect(), sdf6.collect())
+ assertDataFrameEqual(cdf6, sdf6)
cdf7 = cdf1.join(cdf2, cdf1["value"] ==
cdf2["value"]).select(cdf2.value)
sdf7 = sdf1.join(sdf2, sdf1["value"] ==
sdf2["value"]).select(sdf2.value)
- self.assertEqual(cdf7.schema, sdf7.schema)
- self.assertEqual(cdf7.collect(), sdf7.collect())
+ assertDataFrameEqual(cdf7, sdf7)
def test_join_with_cte(self):
cte_query = "with dt as (select 1 as ida) select ida as id from dt"
@@ -362,8 +358,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
cdf2 = self.connect.sql(cte_query)
cdf3 = cdf1.join(cdf2, cdf1.id == cdf2.id)
- self.assertEqual(sdf3.schema, cdf3.schema)
- self.assertEqual(sdf3.collect(), cdf3.collect())
+ assertDataFrameEqual(cdf3, sdf3)
def test_with_columns_renamed(self):
# SPARK-41312: test DataFrame.withColumnsRenamed()
@@ -1621,7 +1616,7 @@ class SparkConnectGCTests(SparkConnectSQLTestCase):
)
# Execute the query, and assert the results are correct.
- self.assertEqual(cdf.collect(), sdf.collect())
+ assertDataFrameEqual(cdf, sdf)
# Verify the metadata of arrow batch chunks.
def split_into_batches(chunks):
diff --git a/python/pyspark/sql/tests/connect/test_connect_collection.py
b/python/pyspark/sql/tests/connect/test_connect_collection.py
index ad48c0508d72..0c19d13d4a3f 100644
--- a/python/pyspark/sql/tests/connect/test_connect_collection.py
+++ b/python/pyspark/sql/tests/connect/test_connect_collection.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.connectutils import should_test_connect,
ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
@@ -62,13 +63,11 @@ class SparkConnectCollectionTests(ReusedMixedTestCase,
PandasOnSparkTestUtils):
cdf = self.connect.sql(query)
sdf = self.spark.sql(query)
- self.assertEqual(cdf.schema, sdf.schema)
+ assertDataFrameEqual(cdf, sdf)
- self.assertEqual(cdf.collect(), sdf.collect())
-
- self.assertEqual(
- cdf.select(CF.date_trunc("year",
cdf.date).alias("year")).collect(),
- sdf.select(SF.date_trunc("year",
sdf.date).alias("year")).collect(),
+ assertDataFrameEqual(
+ cdf.select(CF.date_trunc("year", cdf.date).alias("year")),
+ sdf.select(SF.date_trunc("year", sdf.date).alias("year")),
)
def test_head(self):
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py
b/python/pyspark/sql/tests/connect/test_connect_column.py
index d1d0e9c86e7f..a048c26a63e5 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -41,6 +41,7 @@ from pyspark.sql.types import (
BooleanType,
)
from pyspark.errors import PySparkTypeError, PySparkValueError
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.connectutils import should_test_connect,
ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
@@ -996,8 +997,7 @@ class SparkConnectColumnTests(ReusedMixedTestCase,
PandasOnSparkTestUtils):
sdf = self.spark.createDataFrame(data)
sdf1 = sdf.withColumn("a", sdf["a"].withField("b",
SF.lit(3))).select("a.b")
- self.assertEqual(cdf1.schema, sdf1.schema)
- self.assertEqual(cdf1.collect(), sdf1.collect())
+ assertDataFrameEqual(cdf1, sdf1)
def test_distributed_sequence_id(self):
cdf = self.connect.range(10)
diff --git
a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
index a01293c74b47..318bebf60b0d 100644
--- a/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
+++ b/python/pyspark/sql/tests/connect/test_connect_dataframe_property.py
@@ -28,6 +28,7 @@ from pyspark.sql.types import (
)
from pyspark.sql.utils import is_remote
from pyspark.sql import functions as SF
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.connectutils import should_test_connect,
ReusedMixedTestCase
from pyspark.testing.pandasutils import PandasOnSparkTestUtils
from pyspark.testing.utils import (
@@ -98,8 +99,7 @@ class SparkConnectDataFramePropertyTests(ReusedMixedTestCase,
PandasOnSparkTestU
sdf1 = sdf.to(schema)
- self.assertEqual(cdf1.schema, sdf1.schema)
- self.assertEqual(cdf1.collect(), sdf1.collect())
+ assertDataFrameEqual(cdf1, sdf1)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -142,8 +142,7 @@ class
SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestU
self.assertFalse(is_remote())
sdf1 = sdf.mapInPandas(func, schema)
- self.assertEqual(cdf1.schema, sdf1.schema)
- self.assertEqual(cdf1.collect(), sdf1.collect())
+ assertDataFrameEqual(cdf1, sdf1)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -176,8 +175,7 @@ class
SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestU
self.assertFalse(is_remote())
sdf1 = sdf.mapInArrow(func, schema)
- self.assertEqual(cdf1.schema, sdf1.schema)
- self.assertEqual(cdf1.collect(), sdf1.collect())
+ assertDataFrameEqual(cdf1, sdf1)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -214,8 +212,7 @@ class
SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestU
self.assertFalse(is_remote())
sdf1 = sdf.groupby("id").applyInPandas(normalize, schema)
- self.assertEqual(cdf1.schema, sdf1.schema)
- self.assertEqual(cdf1.collect(), sdf1.collect())
+ assertDataFrameEqual(cdf1, sdf1)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -247,8 +244,7 @@ class
SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestU
self.assertFalse(is_remote())
sdf1 = sdf.groupby("id").applyInArrow(normalize, schema)
- self.assertEqual(cdf1.schema, sdf1.schema)
- self.assertEqual(cdf1.collect(), sdf1.collect())
+ assertDataFrameEqual(cdf1, sdf1)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -284,8 +280,7 @@ class
SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestU
self.assertFalse(is_remote())
sdf3 =
sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInPandas(asof_join, schema)
- self.assertEqual(cdf3.schema, sdf3.schema)
- self.assertEqual(cdf3.collect(), sdf3.collect())
+ assertDataFrameEqual(cdf3, sdf3)
@unittest.skipIf(
not have_pandas or not have_pyarrow,
@@ -324,8 +319,7 @@ class
SparkConnectDataFramePropertyTests(ReusedMixedTestCase, PandasOnSparkTestU
self.assertFalse(is_remote())
sdf3 =
sdf1.groupby("id").cogroup(sdf2.groupby("id")).applyInArrow(summarize, schema)
- self.assertEqual(cdf3.schema, sdf3.schema)
- self.assertEqual(cdf3.collect(), sdf3.collect())
+ assertDataFrameEqual(cdf3, sdf3)
def test_cached_schema_set_op(self):
data1 = [(1, 2, 3)]
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py
b/python/pyspark/sql/tests/connect/test_connect_function.py
index de1584cb3bc3..e53ed6c70a70 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -1672,7 +1672,7 @@ class SparkConnectFunctionTests(ReusedMixedTestCase,
PandasOnSparkTestUtils):
# TODO: 'cdf.schema' has an extra metadata '{'__autoGeneratedAlias':
'true'}'
self.assertEqual(_drop_metadata(cdf.schema),
_drop_metadata(sdf.schema))
- self.assertEqual(cdf.collect(), sdf.collect())
+ assertDataFrameEqual(cdf, sdf)
def test_csv_functions(self):
query = """
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]