allisonwang-db commented on code in PR #41606:
URL: https://github.com/apache/spark/pull/41606#discussion_r1253826970
##########
python/pyspark/testing/utils.py:
##########
@@ -209,3 +234,80 @@ def check_error(
self.assertEqual(
expected, actual, f"Expected message parameters was '{expected}',
got '{actual}'"
)
+
+
+def assertDataFrameEqual(df: DataFrame, expected: DataFrame, ignore_row_order:
bool = True):
Review Comment:
Let's add a docstring for this method and also include some details on how we
approximate the results.
This can be a follow-up for this PR. For the expected value, we should allow
users to pass in a Row or a list of rows, e.g.
```
assertDataFrameEqual(df, Row(a=1))
assertDataFrameEqual(df, [Row(a=1), Row(a=2)])
```
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_equal_nullrow(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_nullval(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 2000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[0]),
blue(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[1]),
red(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_nulldf(self):
+ df1 = None
+ df2 = None
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_ignore_row_order(self):
+ # test that row order is ignored by default
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("2", 3000.00),
+ ("1", 1000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def remove_non_word_characters(self, col):
+ return F.regexp_replace(col, "[^\\w\\s]+", "")
+
+ def test_remove_non_word_characters_long(self):
+ source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)]
+ source_df = self.spark.createDataFrame(source_data, ["name"])
+
+ actual_df = source_df.withColumn(
+ "clean_name", self.remove_non_word_characters(F.col("name"))
+ )
+
+ expected_data = [("jo&&se", "jose"), ("**li**", "li"), ("#::luisa",
"luisa"), (None, None)]
+ expected_df = self.spark.createDataFrame(expected_data, ["name",
"clean_name"])
+
+ assertDataFrameEqual(actual_df, expected_df)
+
+ def test_assert_pyspark_approx_equal(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.0000001),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_pyspark_df_not_equal(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1001.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row([red(df1.collect()[0]),
red(df2.collect()[0])])
+ expected_error_table.add_row([blue(df1.collect()[1]),
blue(df2.collect()[1])])
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_notequal_schema(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ (1, 1000),
+ (2, 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_SCHEMA",
+ message_parameters={"df_schema": df1.schema, "expected_schema":
df2.schema},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"grad_year": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",
+ message_parameters={},
+ )
Review Comment:
Can we also add more tests using spark.sql? E.g.
```
assertDataFrameEqual(
self.spark.sql("select 1 + 2 AS x"),
self.spark.sql("select 3 AS x")
)
```
Also add one that sorts the results (for which we should use ignore_row_order=False):
```
assertDataFrameEqual(
self.spark.sql("select * from ... order by ..."),
self.spark.sql("...")
)
```
You can also try another one that has an invalid SQL query (to see what the
error message looks like):
```
assertDataFrameEqual(
self.spark.sql("select non-existing-column from ..."),
self.spark.sql("...")
)
```
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_equal_nullrow(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_nullval(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 2000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[0]),
blue(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[1]),
red(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_nulldf(self):
+ df1 = None
+ df2 = None
+
+ assertDataFrameEqual(df1, df2)
Review Comment:
How about any other values, like a dictionary or a list?
I think we should also check whether the type of df1 or df2 is pd.DataFrame in
`assertDataFrameEqual`.
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
Review Comment:
Why do we need `ignore_row_order=False` here?
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_equal_nullrow(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_nullval(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 2000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[0]),
blue(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[1]),
red(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_nulldf(self):
+ df1 = None
+ df2 = None
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_ignore_row_order(self):
+ # test that row order is ignored by default
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("2", 3000.00),
+ ("1", 1000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def remove_non_word_characters(self, col):
Review Comment:
We can define this helper function inside the test
`test_remove_non_word_characters_long`, since it is only used there
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
Review Comment:
Can we add another test for the nested struct type with a different schema?
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_equal_nullrow(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_nullval(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 2000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[0]),
blue(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[1]),
red(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_nulldf(self):
+ df1 = None
+ df2 = None
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_ignore_row_order(self):
+ # test that row order is ignored by default
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("2", 3000.00),
+ ("1", 1000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
Review Comment:
We can add one more assertion here, `assertDataFrameEqual(df1, df2,
ignore_row_order=False)`, which should raise an error since the row order differs
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_equal_nullrow(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_nullval(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 2000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[0]),
blue(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[1]),
red(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_nulldf(self):
+ df1 = None
+ df2 = None
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_ignore_row_order(self):
+ # test that row order is ignored by default
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("2", 3000.00),
+ ("1", 1000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def remove_non_word_characters(self, col):
+ return F.regexp_replace(col, "[^\\w\\s]+", "")
+
+ def test_remove_non_word_characters_long(self):
+ source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)]
+ source_df = self.spark.createDataFrame(source_data, ["name"])
+
+ actual_df = source_df.withColumn(
+ "clean_name", self.remove_non_word_characters(F.col("name"))
+ )
+
+ expected_data = [("jo&&se", "jose"), ("**li**", "li"), ("#::luisa",
"luisa"), (None, None)]
+ expected_df = self.spark.createDataFrame(expected_data, ["name",
"clean_name"])
+
+ assertDataFrameEqual(actual_df, expected_df)
+
+ def test_assert_pyspark_approx_equal(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.0000001),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_pyspark_df_not_equal(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1001.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row([red(df1.collect()[0]),
red(df2.collect()[0])])
+ expected_error_table.add_row([blue(df1.collect()[1]),
blue(df2.collect()[1])])
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_notequal_schema(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ (1, 1000),
+ (2, 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_SCHEMA",
+ message_parameters={"df_schema": df1.schema, "expected_schema":
df2.schema},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"grad_year": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",
Review Comment:
Hmm do we throw the same error on the Scala side?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]