This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 4af4ddea116 [SPARK-45552][PS] Introduce flexible parameters to `assertDataFrameEqual`
4af4ddea116 is described below
commit 4af4ddea116d26086550596693ce09674e75bfa3
Author: Haejoon Lee <[email protected]>
AuthorDate: Mon Oct 30 11:07:01 2023 +0900
[SPARK-45552][PS] Introduce flexible parameters to `assertDataFrameEqual`
### What changes were proposed in this pull request?
This PR proposes to add six new parameters to `assertDataFrameEqual`:
`ignoreNullable`, `ignoreColumnOrder`, `ignoreColumnName`, `ignoreColumnType`,
`maxErrors`, and `showOnlyDiff` to provide users with more flexibility in
DataFrame testing.
### Why are the changes needed?
To enhance the utility of `assertDataFrameEqual` by accommodating various
common DataFrame comparison scenarios that users might encounter, without
necessitating manual adjustments or workarounds.
### Does this PR introduce _any_ user-facing change?
Yes. `assertDataFrameEqual` now accepts six new parameters:
Parameter | Type | Comment
-- | -- | --
ignoreNullable | Boolean [optional] | Specifies whether a column’s nullable property is included when checking for schema equality.<br/><br/>When set to True (default), the nullable property of the columns being compared is not taken into account and the columns will be considered equal even if they have different nullable settings.<br/><br/>When set to False, columns are considered equal only if they have the same nullable setting.
ignoreColumnOrder | Boolean [optional] | Specifies whether to compare columns in the order they appear in the DataFrames or by column name.<br/><br/>When set to False (default), columns are compared in the order they appear in the DataFrames.<br/><br/>When set to True, a column in the expected DataFrame is compared to the column with the same name in the actual DataFrame.<br/><br/>ignoreColumnOrder cannot be set to True if ignoreColumnName is also set to True.
ignoreColumnName | Boolean [optional] | Specifies whether to fail the initial schema equality check if the column names in the two DataFrames are different.<br/><br/>When set to False (default), column names are checked and the function fails if they are different.<br/><br/>When set to True, the function will succeed even if column names are different. Column data types are compared for columns in the order they appear in the DataFrames.<br/><br/>ignoreColumnName cannot be set to [...]
ignoreColumnType | Boolean [optional] | Specifies whether to ignore the data type of the columns when comparing.<br/><br/>When set to False (default), column data types are checked and the function fails if they are different.<br/><br/>When set to True, the schema equality check will succeed even if column data types are different and the function will attempt to compare rows.
maxErrors | Integer [optional] | The maximum number of row comparison failures to encounter before returning.<br/><br/>When this number of row comparisons has failed, the function returns regardless of how many rows have been compared.<br/><br/>Set to None by default, which means all rows are compared regardless of the number of failures.
showOnlyDiff | Boolean [optional] | If set to True, the error message will only include rows that are different.<br/><br/>If set to False (default), the error message will include all rows (when there is at least one row that is different).
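
For reference, the sketch below shows how these parameters can be exercised together. It is a minimal, hedged example assuming a Spark build that includes this change and an active `SparkSession`; the DataFrame names and data mirror the doctests added in this patch:

```python
from pyspark.sql import SparkSession
from pyspark.errors import PySparkAssertionError
from pyspark.testing.utils import assertDataFrameEqual

spark = SparkSession.builder.getOrCreate()

# Same data, columns in a different order: passes only with ignoreColumnOrder=True.
df1 = spark.createDataFrame([(1000, "1"), (5000, "2")], ["amount", "id"])
df2 = spark.createDataFrame([("1", 1000), ("2", 5000)], ["id", "amount"])
assertDataFrameEqual(df1, df2, ignoreColumnOrder=True)

# Same data under different column names: passes with ignoreColumnName=True.
df3 = spark.createDataFrame([(1000, "1"), (5000, "2")], ["amount", "identity"])
df4 = spark.createDataFrame([(1000, "1"), (5000, "2")], ["amount", "id"])
assertDataFrameEqual(df3, df4, ignoreColumnName=True)

# Mismatching rows: cap reported failures at 1 and show only differing rows.
df5 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
df6 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
try:
    assertDataFrameEqual(df5, df6, maxErrors=1, showOnlyDiff=True)
except PySparkAssertionError as e:
    print(e)  # reports at most the first differing row pair
```

Combining `maxErrors` and `showOnlyDiff` in one call is an assumption; the patch's doctests exercise each flag separately, though the implementation threads both through the same row-comparison helper.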
### How was this patch tested?
Added usage examples into doctest for each parameter.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43433 from itholic/SPARK-45552.
Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/tests/test_utils.py | 68 +++++++++++
python/pyspark/testing/utils.py | 215 +++++++++++++++++++++++++++++++--
2 files changed, 274 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/sql/tests/test_utils.py b/python/pyspark/sql/tests/test_utils.py
index a2cad4e83bd..421043a41bb 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -1238,6 +1238,9 @@ class UtilsTestsMixin:
assertDataFrameEqual(df1, df2)
+ with self.assertRaises(PySparkAssertionError):
+ assertDataFrameEqual(df1, df2, ignoreNullable=False)
+
def test_schema_ignore_nullable_array_equal(self):
s1 = StructType([StructField("names", ArrayType(DoubleType(), True), True)])
s2 = StructType([StructField("names", ArrayType(DoubleType(), False), False)])
@@ -1611,6 +1614,71 @@ class UtilsTestsMixin:
message_parameters={"error_msg": error_msg},
)
+ def test_dataframe_ignore_column_order(self):
+ df1 = self.spark.createDataFrame([Row(A=1, B=2), Row(A=3, B=4)])
+ df2 = self.spark.createDataFrame([Row(B=2, A=1), Row(B=4, A=3)])
+
+ with self.assertRaises(PySparkAssertionError):
+ assertDataFrameEqual(df1, df2, ignoreColumnOrder=False)
+
+ assertDataFrameEqual(df1, df2, ignoreColumnOrder=True)
+
+ def test_dataframe_ignore_column_name(self):
+ df1 = self.spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"])
+ df2 = self.spark.createDataFrame([(1, 2), (3, 4)], ["X", "Y"])
+
+ with self.assertRaises(PySparkAssertionError):
+ assertDataFrameEqual(df1, df2, ignoreColumnName=False)
+
+ assertDataFrameEqual(df1, df2, ignoreColumnName=True)
+
+ def test_dataframe_ignore_column_type(self):
+ df1 = self.spark.createDataFrame([(1, "2"), (3, "4")], ["A", "B"])
+ df2 = self.spark.createDataFrame([(1, 2), (3, 4)], ["A", "B"])
+
+ with self.assertRaises(PySparkAssertionError):
+ assertDataFrameEqual(df1, df2, ignoreColumnType=False)
+
+ assertDataFrameEqual(df1, df2, ignoreColumnType=True)
+
+ def test_dataframe_max_errors(self):
+ df1 = self.spark.createDataFrame([(1, "a"), (2, "b"), (3, "c"), (4, "d")], ["id", "value"])
+ df2 = self.spark.createDataFrame([(1, "a"), (2, "z"), (3, "x"), (4, "y")], ["id", "value"])
+
+ # We expect differences in rows 2, 3, and 4.
+ # Setting maxErrors to 2 will limit the reported errors.
+ maxErrors = 2
+ with self.assertRaises(PySparkAssertionError) as context:
+ assertDataFrameEqual(df1, df2, maxErrors=maxErrors)
+
+ # Check if the error message contains information about 2 mismatches only.
+ error_message = str(context.exception)
+ self.assertTrue("! Row" in error_message and error_message.count("!
Row") == maxErrors * 2)
+
+ def test_dataframe_show_only_diff(self):
+ df1 = self.spark.createDataFrame(
+ [(1, "apple", "red"), (2, "banana", "yellow"), (3, "cherry",
"red")],
+ ["id", "fruit", "color"],
+ )
+ df2 = self.spark.createDataFrame(
+ [(1, "apple", "green"), (2, "banana", "yellow"), (3, "cherry",
"blue")],
+ ["id", "fruit", "color"],
+ )
+
+ with self.assertRaises(PySparkAssertionError) as context:
+ assertDataFrameEqual(df1, df2, showOnlyDiff=False)
+
+ error_message = str(context.exception)
+
+ self.assertTrue("apple" in error_message and "banana" in error_message)
+
+ with self.assertRaises(PySparkAssertionError) as context:
+ assertDataFrameEqual(df1, df2, showOnlyDiff=True)
+
+ error_message = str(context.exception)
+
+ self.assertTrue("apple" in error_message and "banana" not in
error_message)
+
class UtilsTests(ReusedSQLTestCase, UtilsTestsMixin):
def test_capture_analysis_exception(self):
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 5ee27862923..282f4cc1cf5 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -22,6 +22,7 @@ import sys
import unittest
import difflib
import functools
+import math
from decimal import Decimal
from time import time, sleep
from typing import (
@@ -57,6 +58,7 @@ from pyspark.find_spark_home import _find_spark_home
from pyspark.sql.dataframe import DataFrame
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField
+from pyspark.sql.functions import col, when
__all__ = ["assertDataFrameEqual", "assertSchemaEqual"]
@@ -396,6 +398,12 @@ def assertDataFrameEqual(
checkRowOrder: bool = False,
rtol: float = 1e-5,
atol: float = 1e-8,
+ ignoreNullable: bool = True,
+ ignoreColumnOrder: bool = False,
+ ignoreColumnName: bool = False,
+ ignoreColumnType: bool = False,
+ maxErrors: Optional[int] = None,
+ showOnlyDiff: bool = False,
):
r"""
A util function to assert equality between `actual` and `expected`
@@ -424,6 +432,55 @@ def assertDataFrameEqual(
atol : float, optional
The absolute tolerance, used in asserting approximate equality for float values in actual
and expected. Set to 1e-8 by default. (See Notes)
+ ignoreNullable : bool, default True
+ Specifies whether a column’s nullable property is included when checking for
+ schema equality.
+ When set to `True` (default), the nullable property of the columns being compared
+ is not taken into account and the columns will be considered equal even if they have
+ different nullable settings.
+ When set to `False`, columns are considered equal only if they have the same nullable
+ setting.
+
+ .. versionadded:: 4.0.0
+ ignoreColumnOrder : bool, default False
+ Specifies whether to compare columns in the order they appear in the DataFrame or by
+ column name.
+ If set to `False` (default), columns are compared in the order they appear in the
+ DataFrames.
+ When set to `True`, a column in the expected DataFrame is compared to the column with the
+ same name in the actual DataFrame.
+
+ .. versionadded:: 4.0.0
+ ignoreColumnName : bool, default False
+ Specifies whether to fail the initial schema equality check if the column names in the two
+ DataFrames are different.
+ When set to `False` (default), column names are checked and the function fails if they are
+ different.
+ When set to `True`, the function will succeed even if column names are different.
+ Column data types are compared for columns in the order they appear in the DataFrames.
+
+ .. versionadded:: 4.0.0
+ ignoreColumnType : bool, default False
+ Specifies whether to ignore the data type of the columns when comparing.
+ When set to `False` (default), column data types are checked and the function fails if they
+ are different.
+ When set to `True`, the schema equality check will succeed even if column data types are
+ different and the function will attempt to compare rows.
+
+ .. versionadded:: 4.0.0
+ maxErrors : int, optional
+ The maximum number of row comparison failures to encounter before returning.
+ When this number of row comparisons has failed, the function returns regardless of how
+ many rows have been compared.
+ Set to None by default, which means all rows are compared regardless of the number of failures.
+
+ .. versionadded:: 4.0.0
+ showOnlyDiff : bool, default False
+ If set to `True`, the error message will only include rows that are different.
+ If set to `False` (default), the error message will include all rows
+ (when there is at least one row that is different).
+
+ .. versionadded:: 4.0.0
Notes
-----
@@ -440,6 +497,9 @@ def assertDataFrameEqual(
``absolute(a - b) <= (atol + rtol * absolute(b))``.
+ `ignoreColumnOrder` cannot be set to `True` if `ignoreColumnName` is also set to `True`.
+ `ignoreColumnName` cannot be set to `True` if `ignoreColumnOrder` is also set to `True`.
+
Examples
--------
>>> df1 = spark.createDataFrame(data=[("1", 1000), ("2", 3000)], schema=["id", "amount"])
@@ -469,12 +529,101 @@ def assertDataFrameEqual(
PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
*** actual ***
! Row(id='1', amount=1000.0)
- Row(id='2', amount=3000.0)
+ Row(id='2', amount=3000.0)
! Row(id='3', amount=2000.0)
*** expected ***
! Row(id='1', amount=1001.0)
- Row(id='2', amount=3000.0)
+ Row(id='2', amount=3000.0)
! Row(id='3', amount=2003.0)
+
+ Example for ignoreNullable
+
+ >>> from pyspark.sql.types import StructType, StructField, StringType, LongType
+ >>> df1_nullable = spark.createDataFrame(
+ ... data=[(1000, "1"), (5000, "2")],
+ ... schema=StructType(
+ ... [StructField("amount", LongType(), True), StructField("id",
StringType(), True)]
+ ... )
+ ... )
+ >>> df2_nullable = spark.createDataFrame(
+ ... data=[(1000, "1"), (5000, "2")],
+ ... schema=StructType(
+ ... [StructField("amount", LongType(), True), StructField("id",
StringType(), False)]
+ ... )
+ ... )
+ >>> assertDataFrameEqual(df1_nullable, df2_nullable, ignoreNullable=True)  # pass
+ >>> assertDataFrameEqual(
+ ... df1_nullable, df2_nullable, ignoreNullable=False
+ ... ) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ PySparkAssertionError: [DIFFERENT_SCHEMA] Schemas do not match.
+ --- actual
+ +++ expected
+ - StructType([StructField('amount', LongType(), True), StructField('id', StringType(), True)])
+ ?                                                                                      ^^^
+ + StructType([StructField('amount', LongType(), True), StructField('id', StringType(), False)])
+ ?                                                                                      ^^^^
+
+ Example for ignoreColumnOrder
+
+ >>> df1_col_order = spark.createDataFrame(
+ ... data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
+ ... )
+ >>> df2_col_order = spark.createDataFrame(
+ ... data=[("1", 1000), ("2", 5000)], schema=["id", "amount"]
+ ... )
+ >>> assertDataFrameEqual(df1_col_order, df2_col_order, ignoreColumnOrder=True)
+
+ Example for ignoreColumnName
+
+ >>> df1_col_names = spark.createDataFrame(
+ ... data=[(1000, "1"), (5000, "2")], schema=["amount", "identity"]
+ ... )
+ >>> df2_col_names = spark.createDataFrame(
+ ... data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
+ ... )
+ >>> assertDataFrameEqual(df1_col_names, df2_col_names, ignoreColumnName=True)
+
+ Example for ignoreColumnType
+
+ >>> df1_col_types = spark.createDataFrame(
+ ... data=[(1000, "1"), (5000, "2")], schema=["amount", "id"]
+ ... )
+ >>> df2_col_types = spark.createDataFrame(
+ ... data=[(1000.0, "1"), (5000.0, "2")], schema=["amount", "id"]
+ ... )
+ >>> assertDataFrameEqual(df1_col_types, df2_col_types, ignoreColumnType=True)
+
+ Example for maxErrors (will only report the first mismatching row)
+
+ >>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
+ >>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
+ >>> assertDataFrameEqual(df1, df2, maxErrors=1) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 33.33333 % )
+ *** actual ***
+ Row(_1=1, _2='A')
+ ! Row(_1=2, _2='B')
+ *** expected ***
+ Row(_1=1, _2='A')
+ ! Row(_1=2, _2='X')
+
+ Example for showOnlyDiff (will only report the mismatching rows)
+
+ >>> df1 = spark.createDataFrame([(1, "A"), (2, "B"), (3, "C")])
+ >>> df2 = spark.createDataFrame([(1, "A"), (2, "X"), (3, "Y")])
+ >>> assertDataFrameEqual(df1, df2, showOnlyDiff=True) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ PySparkAssertionError: [DIFFERENT_ROWS] Results do not match: ( 66.66667 % )
+ *** actual ***
+ ! Row(_1=2, _2='B')
+ ! Row(_1=3, _2='C')
+ *** expected ***
+ ! Row(_1=2, _2='X')
+ ! Row(_1=3, _2='Y')
"""
if actual is None and expected is None:
return True
@@ -546,6 +695,37 @@ def assertDataFrameEqual(
},
)
+ if ignoreColumnOrder:
+ actual = actual.select(*sorted(actual.columns))
+ expected = expected.select(*sorted(expected.columns))
+
+ def rename_dataframe_columns(df: DataFrame) -> DataFrame:
+ """Rename DataFrame columns to sequential numbers for comparison"""
+ renamed_columns = [str(i) for i in range(len(df.columns))]
+ return df.toDF(*renamed_columns)
+
+ if ignoreColumnName:
+ actual = rename_dataframe_columns(actual)
+ expected = rename_dataframe_columns(expected)
+
+ def cast_columns_to_string(df: DataFrame) -> DataFrame:
+ """Cast all DataFrame columns to string for comparison"""
+ for col_name in df.columns:
+ # Add logic to remove trailing .0 for float columns that are whole numbers
+ df = df.withColumn(
+ col_name,
+ when(
+ (col(col_name).cast("float").isNotNull())
+ & (col(col_name).cast("float") ==
col(col_name).cast("int")),
+ col(col_name).cast("int").cast("string"),
+ ).otherwise(col(col_name).cast("string")),
+ )
+ return df
+
+ if ignoreColumnType:
+ actual = cast_columns_to_string(actual)
+ expected = cast_columns_to_string(expected)
+
def compare_rows(r1: Row, r2: Row):
def compare_vals(val1, val2):
if isinstance(val1, list) and isinstance(val2, list):
@@ -578,7 +758,9 @@ def assertDataFrameEqual(
return compare_vals(r1, r2)
- def assert_rows_equal(rows1: List[Row], rows2: List[Row]):
+ def assert_rows_equal(
+ rows1: List[Row], rows2: List[Row], maxErrors: Optional[int] = None, showOnlyDiff: bool = False
+ ):
zipped = list(zip_longest(rows1, rows2))
diff_rows_cnt = 0
diff_rows = False
@@ -588,11 +770,16 @@ def assertDataFrameEqual(
# count different rows
for r1, r2 in zipped:
- rows_str1 += str(r1) + "\n"
- rows_str2 += str(r2) + "\n"
if not compare_rows(r1, r2):
diff_rows_cnt += 1
diff_rows = True
+ rows_str1 += str(r1) + "\n"
+ rows_str2 += str(r2) + "\n"
+ if maxErrors is not None and diff_rows_cnt >= maxErrors:
+ break
+ elif not showOnlyDiff:
+ rows_str1 += str(r1) + "\n"
+ rows_str2 += str(r2) + "\n"
generated_diff = _context_diff(
actual=rows_str1.splitlines(), expected=rows_str2.splitlines(), n=len(zipped)
@@ -608,10 +795,20 @@ def assertDataFrameEqual(
message_parameters={"error_msg": error_msg},
)
- # convert actual and expected to list
+ # only compare schema if expected is not a List
if not isinstance(actual, list) and not isinstance(expected, list):
- # only compare schema if expected is not a List
- assertSchemaEqual(actual.schema, expected.schema)
+ if ignoreNullable:
+ assertSchemaEqual(actual.schema, expected.schema)
+ elif actual.schema != expected.schema:
+ generated_diff = difflib.ndiff(
+ str(actual.schema).splitlines(), str(expected.schema).splitlines()
+ )
+ error_msg = "\n".join(generated_diff)
+
+ raise PySparkAssertionError(
+ error_class="DIFFERENT_SCHEMA",
+ message_parameters={"error_msg": error_msg},
+ )
if not isinstance(actual, list):
actual_list = actual.collect()
@@ -628,7 +825,7 @@ def assertDataFrameEqual(
actual_list = sorted(actual_list, key=lambda x: str(x))
expected_list = sorted(expected_list, key=lambda x: str(x))
- assert_rows_equal(actual_list, expected_list)
+ assert_rows_equal(actual_list, expected_list, maxErrors=maxErrors, showOnlyDiff=showOnlyDiff)
def _test() -> None:
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]