asl3 commented on code in PR #41606:
URL: https://github.com/apache/spark/pull/41606#discussion_r1253963962
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_equal_nullrow(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_nullval(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 2000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[0]),
blue(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[1]),
red(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_nulldf(self):
+ df1 = None
+ df2 = None
+
+ assertDataFrameEqual(df1, df2)
Review Comment:
@allisonwang-db Sure — for now, let's throw an error when any data type other
than a PySpark DataFrame is passed in.
In a follow-up PR, we can update `assertDataFrameEqual` to work for both pandas
and PySpark DataFrames (by calling the appropriate helper).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]