ueshin commented on code in PR #41606:
URL: https://github.com/apache/spark/pull/41606#discussion_r1254746622
##########
python/pyspark/testing/utils.py:
##########
@@ -209,3 +232,144 @@ def check_error(
self.assertEqual(
expected, actual, f"Expected message parameters was '{expected}',
got '{actual}'"
)
+
+
+def assertDataFrameEqual(
+ df: DataFrame, expected: Union[DataFrame, List[Row]], ignore_row_order:
bool = True
+):
+ """
+ A util function to assert equality between DataFrames `df` and `expected`,
with
+ optional parameter `ignore_row_order`.
+
+ For float values, assert approximate equality (1e-5 by default).
+
+ Parameters
+ ----------
+ df : DataFrame
+ expected : DataFrame or List of Row
+ ignore_row_order: bool, default True
+ """
+ if df is None and expected is None:
+ return True
+ elif df is None or expected is None:
+ return False
+
+ try:
+ # If Spark Connect dependencies are available, allow Spark Connect
DataFrame
+ from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+
+ if not isinstance(df, DataFrame) and not isinstance(df,
ConnectDataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(df)},
+ )
+ elif not isinstance(expected, DataFrame) and not isinstance(expected,
ConnectDataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(expected)},
+ )
+ except:
+ if not isinstance(df, DataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(df)},
+ )
+ elif not isinstance(expected, DataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(expected)},
+ )
+
+ def rename_duplicate_cols(input_df):
+ df_cols = input_df.columns
+
+ duplicate_col_indices = [idx for idx, val in enumerate(df_cols) if val
in df_cols[:idx]]
+
+ # Create a new list by renaming duplicate
+ # columns by adding prefix '_duplicate_'+index
+ for i in duplicate_col_indices:
+ df_cols[i] = df_cols[i] + "_duplicate_" + str(i)
+
+ # Rename duplicate columns
+ result_df = input_df.toDF(*df_cols)
+
+ return result_df
+
+ def compare_rows(r1: Row, r2: Row):
+ def compare_vals(val1, val2):
+ if isinstance(val1, list) and isinstance(val2, list):
+ return len(val1) == len(val2) and all(
+ compare_vals(x, y) for x, y in zip(val1, val2)
+ )
+ elif isinstance(val1, Row) and isinstance(val2, Row):
+ return all(compare_vals(x, y) for x, y in zip(val1, val2))
+ elif isinstance(val1, dict) and isinstance(val2, dict):
+ return (
+ len(val1.keys()) == len(val2.keys())
+ and val1.keys() == val2.keys()
+ and all(compare_vals(val1[k], val2[k]) for k in
val1.keys())
+ )
+ elif isinstance(val1, float) and isinstance(val2, float):
+ if abs(val1 - val2) > 1e-5:
+ return False
+ else:
+ if val1 != val2:
+ return False
+ return True
+
+ if r1 is None and r2 is None:
+ return True
+ elif r1 is None or r2 is None:
+ return False
+
+ return compare_vals(r1, r2)
+
+ def assert_schema_equal(
+ df_schema: StructType,
+ expected_schema: StructType,
+ ):
+ if df_schema != expected_schema:
+ raise PySparkAssertionError(
+ error_class="DIFFERENT_SCHEMA",
+ message_parameters={"df_schema": df_schema, "expected_schema":
expected_schema},
+ )
+
+ def assert_rows_equal(rows1: Row, rows2: Row):
+ zipped = list(zip_longest(rows1, rows2))
+ rows_equal = True
+ error_msg = "Results do not match: "
+ diff_msg = ""
+ diff_rows_cnt = 0
+
+ for r1, r2 in zipped:
+ if not compare_rows(r1, r2):
+ rows_equal = False
+ diff_rows_cnt += 1
+ diff_msg += (
+ "[df]" + "\n" + str(r1) + "\n\n" + "[expected]" + "\n" +
str(r2) + "\n\n"
+ )
+ diff_msg += "********************" + "\n\n"
+
+ if not rows_equal:
+ percent_diff = diff_rows_cnt / len(zipped)
+ error_msg += "( %.5f %% )" % percent_diff
+ error_msg += "\n" + diff_msg
+ raise PySparkAssertionError(
+ error_class="DIFFERENT_ROWS",
+ message_parameters={"error_msg": error_msg},
+ )
+
+ if ignore_row_order:
+ try:
+ df = rename_duplicate_cols(df)
+ expected = rename_duplicate_cols(expected)
Review Comment:
If we do need this, the current implementation is insufficient because it
doesn't handle nested struct field names.
##########
python/pyspark/testing/utils.py:
##########
@@ -209,3 +232,144 @@ def check_error(
self.assertEqual(
expected, actual, f"Expected message parameters was '{expected}',
got '{actual}'"
)
+
+
+def assertDataFrameEqual(
+ df: DataFrame, expected: Union[DataFrame, List[Row]], ignore_row_order:
bool = True
+):
+ """
+ A util function to assert equality between DataFrames `df` and `expected`,
with
+ optional parameter `ignore_row_order`.
+
+ For float values, assert approximate equality (1e-5 by default).
+
+ Parameters
+ ----------
+ df : DataFrame
+ expected : DataFrame or List of Row
+ ignore_row_order: bool, default True
+ """
+ if df is None and expected is None:
+ return True
+ elif df is None or expected is None:
+ return False
+
+ try:
+ # If Spark Connect dependencies are available, allow Spark Connect
DataFrame
+ from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
+
+ if not isinstance(df, DataFrame) and not isinstance(df,
ConnectDataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(df)},
+ )
+ elif not isinstance(expected, DataFrame) and not isinstance(expected,
ConnectDataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(expected)},
+ )
+ except:
+ if not isinstance(df, DataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(df)},
+ )
+ elif not isinstance(expected, DataFrame):
+ raise PySparkAssertionError(
+ error_class="UNSUPPORTED_DATA_TYPE",
+ message_parameters={"data_type": type(expected)},
+ )
+
+ def rename_duplicate_cols(input_df):
+ df_cols = input_df.columns
+
+ duplicate_col_indices = [idx for idx, val in enumerate(df_cols) if val
in df_cols[:idx]]
+
+ # Create a new list by renaming duplicate
+ # columns by adding prefix '_duplicate_'+index
+ for i in duplicate_col_indices:
+ df_cols[i] = df_cols[i] + "_duplicate_" + str(i)
+
+ # Rename duplicate columns
+ result_df = input_df.toDF(*df_cols)
+
+ return result_df
+
+ def compare_rows(r1: Row, r2: Row):
+ def compare_vals(val1, val2):
+ if isinstance(val1, list) and isinstance(val2, list):
+ return len(val1) == len(val2) and all(
+ compare_vals(x, y) for x, y in zip(val1, val2)
+ )
+ elif isinstance(val1, Row) and isinstance(val2, Row):
+ return all(compare_vals(x, y) for x, y in zip(val1, val2))
+ elif isinstance(val1, dict) and isinstance(val2, dict):
+ return (
+ len(val1.keys()) == len(val2.keys())
+ and val1.keys() == val2.keys()
+ and all(compare_vals(val1[k], val2[k]) for k in
val1.keys())
+ )
+ elif isinstance(val1, float) and isinstance(val2, float):
+ if abs(val1 - val2) > 1e-5:
+ return False
+ else:
+ if val1 != val2:
+ return False
+ return True
+
+ if r1 is None and r2 is None:
+ return True
+ elif r1 is None or r2 is None:
+ return False
+
+ return compare_vals(r1, r2)
+
+ def assert_schema_equal(
+ df_schema: StructType,
+ expected_schema: StructType,
+ ):
+ if df_schema != expected_schema:
+ raise PySparkAssertionError(
+ error_class="DIFFERENT_SCHEMA",
+ message_parameters={"df_schema": df_schema, "expected_schema":
expected_schema},
+ )
+
+ def assert_rows_equal(rows1: Row, rows2: Row):
+ zipped = list(zip_longest(rows1, rows2))
+ rows_equal = True
+ error_msg = "Results do not match: "
+ diff_msg = ""
+ diff_rows_cnt = 0
+
+ for r1, r2 in zipped:
+ if not compare_rows(r1, r2):
+ rows_equal = False
+ diff_rows_cnt += 1
+ diff_msg += (
+ "[df]" + "\n" + str(r1) + "\n\n" + "[expected]" + "\n" +
str(r2) + "\n\n"
+ )
+ diff_msg += "********************" + "\n\n"
+
+ if not rows_equal:
+ percent_diff = diff_rows_cnt / len(zipped)
+ error_msg += "( %.5f %% )" % percent_diff
+ error_msg += "\n" + diff_msg
+ raise PySparkAssertionError(
+ error_class="DIFFERENT_ROWS",
+ message_parameters={"error_msg": error_msg},
+ )
+
+ if ignore_row_order:
+ try:
+ df = rename_duplicate_cols(df)
+ expected = rename_duplicate_cols(expected)
Review Comment:
Do we need this?
I think `assert_schema_equal` already checks duplicated names properly, and
`assert_rows_equal` no longer needs to check column/field names.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]