asl3 commented on code in PR #41606:
URL: https://github.com/apache/spark/pull/41606#discussion_r1253985011


##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
 # limitations under the License.
 #
 
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
 from pyspark.errors import (
     AnalysisException,
     ParseException,
+    PySparkAssertionError,
     IllegalArgumentException,
     SparkUpgradeException,
 )
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, pandas_requirement_message
+import pyspark.sql.functions as F
 from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+    StringType,
+    ArrayType,
+    LongType,
+    StructType,
+    MapType,
+    FloatType,
+    DoubleType,
+    StructField,
+    TimestampType,
+)
+
+
+class UtilsTestsMixin:
+    def test_assert_equal_inttype(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000),
+                ("2", 3000),
+            ],
+            schema=["id", "amount"],
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000),
+                ("2", 3000),
+            ],
+            schema=["id", "amount"],
+        )
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_equal_arraytype(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("john", ["Python", "Java"]),
+                ("jane", ["Scala", "SQL", "Java"]),
+            ],
+            schema=StructType(
+                [
+                    StructField("name", StringType(), True),
+                    StructField("languages", ArrayType(StringType()), True),
+                ]
+            ),
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("john", ["Python", "Java"]),
+                ("jane", ["Scala", "SQL", "Java"]),
+            ],
+            schema=StructType(
+                [
+                    StructField("name", StringType(), True),
+                    StructField("languages", ArrayType(StringType()), True),
+                ]
+            ),
+        )
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_approx_equal_arraytype_float(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("student1", [97.01, 89.23]),
+                ("student2", [91.86, 84.34]),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("grades", ArrayType(FloatType()), True),
+                ]
+            ),
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("student1", [97.01, 89.23]),
+                ("student2", [91.86, 84.339999]),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("grades", ArrayType(FloatType()), True),
+                ]
+            ),
+        )
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_notequal_arraytype(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("John", ["Python", "Java"]),
+                ("Jane", ["Scala", "SQL", "Java"]),
+            ],
+            schema=StructType(
+                [
+                    StructField("name", StringType(), True),
+                    StructField("languages", ArrayType(StringType()), True),
+                ]
+            ),
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("John", ["Python", "Java"]),
+                ("Jane", ["Scala", "Java"]),
+            ],
+            schema=StructType(
+                [
+                    StructField("name", StringType(), True),
+                    StructField("languages", ArrayType(StringType()), True),
+                ]
+            ),
+        )
+
+        expected_error_table = PrettyTable(["df", "expected"])
+        expected_error_table.add_row(
+            [red(df1.sort(df1.columns).collect()[0]), 
red(df2.sort(df2.columns).collect()[0])]
+        )
+        expected_error_table.add_row(
+            [blue(df1.sort(df1.columns).collect()[1]), 
blue(df2.sort(df2.columns).collect()[1])]
+        )
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_DATAFRAME",
+            message_parameters={"error_table": 
expected_error_table.get_string()},
+        )
+
+    def test_assert_equal_maptype(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"id": 222342203655477580}),
+                ("student2", {"id": 422322203155477692}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("properties", MapType(StringType(), 
LongType()), True),
+                ]
+            ),
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"id": 222342203655477580}),
+                ("student2", {"id": 422322203155477692}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("properties", MapType(StringType(), 
LongType()), True),
+                ]
+            ),
+        )
+
+        assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+    def test_assert_approx_equal_maptype_double(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"math": 76.23, "english": 92.64}),
+                ("student2", {"math": 87.89, "english": 84.48}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("grades", MapType(StringType(), DoubleType()), 
True),
+                ]
+            ),
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"math": 76.23, "english": 92.63999999}),
+                ("student2", {"math": 87.89, "english": 84.48}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("grades", MapType(StringType(), DoubleType()), 
True),
+                ]
+            ),
+        )
+
+        assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+    def test_assert_approx_equal_maptype_double(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"math": 76.23, "english": 92.64}),
+                ("student2", {"math": 87.89, "english": 84.48}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("grades", MapType(StringType(), DoubleType()), 
True),
+                ]
+            ),
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"math": 76.23, "english": 92.63999999}),
+                ("student2", {"math": 87.89, "english": 84.48}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("grades", MapType(StringType(), DoubleType()), 
True),
+                ]
+            ),
+        )
+
+        assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+    def test_assert_approx_equal_nested_struct_double(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("jane", (64.57, 76.63, 97.81)),
+                ("john", (93.92, 91.57, 84.36)),
+            ],
+            schema=StructType(
+                [
+                    StructField("name", StringType(), True),
+                    StructField(
+                        "grades",
+                        StructType(
+                            [
+                                StructField("math", DoubleType(), True),
+                                StructField("english", DoubleType(), True),
+                                StructField("biology", DoubleType(), True),
+                            ]
+                        ),
+                    ),
+                ]
+            ),
+        )
+
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("jane", (64.57, 76.63, 97.81000001)),
+                ("john", (93.92, 91.57, 84.36)),
+            ],
+            schema=StructType(
+                [
+                    StructField("name", StringType(), True),
+                    StructField(
+                        "grades",
+                        StructType(
+                            [
+                                StructField("math", DoubleType(), True),
+                                StructField("english", DoubleType(), True),
+                                StructField("biology", DoubleType(), True),
+                            ]
+                        ),
+                    ),
+                ]
+            ),
+        )
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_equal_timestamp(self):
+        df1 = self.spark.createDataFrame(
+            data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+        )
+
+        df2 = self.spark.createDataFrame(
+            data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+        )
+
+        df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+        df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+        assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+    def test_assert_equal_nullrow(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000),
+                (None, None),
+            ],
+            schema=["id", "amount"],
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000),
+                (None, None),
+            ],
+            schema=["id", "amount"],
+        )
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_notequal_nullval(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000),
+                ("2", 2000),
+            ],
+            schema=["id", "amount"],
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000),
+                ("2", None),
+            ],
+            schema=["id", "amount"],
+        )
+
+        expected_error_table = PrettyTable(["df", "expected"])
+        expected_error_table.add_row(
+            [blue(df1.sort(df1.columns).collect()[0]), 
blue(df2.sort(df2.columns).collect()[0])]
+        )
+        expected_error_table.add_row(
+            [red(df1.sort(df1.columns).collect()[1]), 
red(df2.sort(df2.columns).collect()[1])]
+        )
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_DATAFRAME",
+            message_parameters={"error_table": 
expected_error_table.get_string()},
+        )
+
+    def test_assert_equal_nulldf(self):
+        df1 = None
+        df2 = None
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_ignore_row_order(self):
+        # test that row order is ignored by default
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("2", 3000.00),
+                ("1", 1000.00),
+            ],
+            schema=["id", "amount"],
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000.00),
+                ("2", 3000.00),
+            ],
+            schema=["id", "amount"],
+        )
+
+        assertDataFrameEqual(df1, df2)
+
+    def remove_non_word_characters(self, col):
+        return F.regexp_replace(col, "[^\\w\\s]+", "")
+
+    def test_remove_non_word_characters_long(self):
+        source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)]
+        source_df = self.spark.createDataFrame(source_data, ["name"])
+
+        actual_df = source_df.withColumn(
+            "clean_name", self.remove_non_word_characters(F.col("name"))
+        )
+
+        expected_data = [("jo&&se", "jose"), ("**li**", "li"), ("#::luisa", 
"luisa"), (None, None)]
+        expected_df = self.spark.createDataFrame(expected_data, ["name", 
"clean_name"])
+
+        assertDataFrameEqual(actual_df, expected_df)
+
+    def test_assert_pyspark_approx_equal(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000.00),
+                ("2", 3000.00),
+            ],
+            schema=["id", "amount"],
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000.0000001),
+                ("2", 3000.00),
+            ],
+            schema=["id", "amount"],
+        )
+
+        assertDataFrameEqual(df1, df2)
+
+    def test_assert_pyspark_df_not_equal(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000.00),
+                ("2", 3000.00),
+            ],
+            schema=["id", "amount"],
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("1", 1001.00),
+                ("2", 3000.00),
+            ],
+            schema=["id", "amount"],
+        )
+
+        expected_error_table = PrettyTable(["df", "expected"])
+        expected_error_table.add_row([red(df1.collect()[0]), 
red(df2.collect()[0])])
+        expected_error_table.add_row([blue(df1.collect()[1]), 
blue(df2.collect()[1])])
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_DATAFRAME",
+            message_parameters={"error_table": 
expected_error_table.get_string()},
+        )
+
+    def test_assert_notequal_schema(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                (1, 1000),
+                (2, 3000),
+            ],
+            schema=["id", "amount"],
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("1", 1000),
+                ("2", 3000),
+            ],
+            schema=["id", "amount"],
+        )
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="DIFFERENT_SCHEMA",
+            message_parameters={"df_schema": df1.schema, "expected_schema": 
df2.schema},
+        )
+
+    def test_assert_equal_maptype(self):
+        df1 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"id": 222342203655477580}),
+                ("student2", {"grad_year": 422322203155477692}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("properties", MapType(StringType(), 
LongType()), True),
+                ]
+            ),
+        )
+        df2 = self.spark.createDataFrame(
+            data=[
+                ("student1", {"id": 222342203655477580}),
+                ("student2", {"id": 422322203155477692}),
+            ],
+            schema=StructType(
+                [
+                    StructField("student", StringType(), True),
+                    StructField("properties", MapType(StringType(), 
LongType()), True),
+                ]
+            ),
+        )
+
+        with self.assertRaises(PySparkAssertionError) as pe:
+            assertDataFrameEqual(df1, df2)
+
+        self.check_error(
+            exception=pe.exception,
+            error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",

Review Comment:
   Hm, I don't see [the Scala side](https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala) catching this same error — should we still keep it?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to