asl3 commented on code in PR #41606:
URL: https://github.com/apache/spark/pull/41606#discussion_r1253981371
##########
python/pyspark/sql/tests/test_utils.py:
##########
@@ -16,18 +16,498 @@
# limitations under the License.
#
-from pyspark.sql.functions import sha2
+import unittest
+from prettytable import PrettyTable
+
+from pyspark.sql.functions import sha2, to_timestamp
from pyspark.errors import (
AnalysisException,
ParseException,
+ PySparkAssertionError,
IllegalArgumentException,
SparkUpgradeException,
)
-from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import assertDataFrameEqual, blue, red
+from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas,
pandas_requirement_message
+import pyspark.sql.functions as F
from pyspark.sql.functions import to_date, unix_timestamp, from_unixtime
+from pyspark.sql.types import (
+ StringType,
+ ArrayType,
+ LongType,
+ StructType,
+ MapType,
+ FloatType,
+ DoubleType,
+ StructField,
+ TimestampType,
+)
+
+
+class UtilsTestsMixin:
+ def test_assert_equal_inttype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("john", ["Python", "Java"]),
+ ("jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_approx_equal_arraytype_float(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.34]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", [97.01, 89.23]),
+ ("student2", [91.86, 84.339999]),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", ArrayType(FloatType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_arraytype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "SQL", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("John", ["Python", "Java"]),
+ ("Jane", ["Scala", "Java"]),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField("languages", ArrayType(StringType()), True),
+ ]
+ ),
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[0]),
red(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[1]),
blue(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_maptype_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.64}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"math": 76.23, "english": 92.63999999}),
+ ("student2", {"math": 87.89, "english": 84.48}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("grades", MapType(StringType(), DoubleType()),
True),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_approx_equal_nested_struct_double(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("jane", (64.57, 76.63, 97.81000001)),
+ ("john", (93.92, 91.57, 84.36)),
+ ],
+ schema=StructType(
+ [
+ StructField("name", StringType(), True),
+ StructField(
+ "grades",
+ StructType(
+ [
+ StructField("math", DoubleType(), True),
+ StructField("english", DoubleType(), True),
+ StructField("biology", DoubleType(), True),
+ ]
+ ),
+ ),
+ ]
+ ),
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_equal_timestamp(self):
+ df1 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df2 = self.spark.createDataFrame(
+ data=[("1", "2023-01-01 12:01:01.000")], schema=["id", "timestamp"]
+ )
+
+ df1 = df1.withColumn("timestamp", to_timestamp("timestamp"))
+ df2 = df2.withColumn("timestamp", to_timestamp("timestamp"))
+
+ assertDataFrameEqual(df1, df2, ignore_row_order=False)
+
+ def test_assert_equal_nullrow(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ (None, None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_notequal_nullval(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 2000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", None),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row(
+ [blue(df1.sort(df1.columns).collect()[0]),
blue(df2.sort(df2.columns).collect()[0])]
+ )
+ expected_error_table.add_row(
+ [red(df1.sort(df1.columns).collect()[1]),
red(df2.sort(df2.columns).collect()[1])]
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_equal_nulldf(self):
+ df1 = None
+ df2 = None
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_ignore_row_order(self):
+ # test that row order is ignored by default
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("2", 3000.00),
+ ("1", 1000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def remove_non_word_characters(self, col):
+ return F.regexp_replace(col, "[^\\w\\s]+", "")
+
+ def test_remove_non_word_characters_long(self):
+ source_data = [("jo&&se",), ("**li**",), ("#::luisa",), (None,)]
+ source_df = self.spark.createDataFrame(source_data, ["name"])
+
+ actual_df = source_df.withColumn(
+ "clean_name", self.remove_non_word_characters(F.col("name"))
+ )
+
+ expected_data = [("jo&&se", "jose"), ("**li**", "li"), ("#::luisa",
"luisa"), (None, None)]
+ expected_df = self.spark.createDataFrame(expected_data, ["name",
"clean_name"])
+
+ assertDataFrameEqual(actual_df, expected_df)
+
+ def test_assert_pyspark_approx_equal(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.0000001),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ assertDataFrameEqual(df1, df2)
+
+ def test_assert_pyspark_df_not_equal(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1001.00),
+ ("2", 3000.00),
+ ],
+ schema=["id", "amount"],
+ )
+
+ expected_error_table = PrettyTable(["df", "expected"])
+ expected_error_table.add_row([red(df1.collect()[0]),
red(df2.collect()[0])])
+ expected_error_table.add_row([blue(df1.collect()[1]),
blue(df2.collect()[1])])
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_DATAFRAME",
+ message_parameters={"error_table":
expected_error_table.get_string()},
+ )
+
+ def test_assert_notequal_schema(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ (1, 1000),
+ (2, 3000),
+ ],
+ schema=["id", "amount"],
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("1", 1000),
+ ("2", 3000),
+ ],
+ schema=["id", "amount"],
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="DIFFERENT_SCHEMA",
+ message_parameters={"df_schema": df1.schema, "expected_schema":
df2.schema},
+ )
+
+ def test_assert_equal_maptype(self):
+ df1 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"grad_year": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+ df2 = self.spark.createDataFrame(
+ data=[
+ ("student1", {"id": 222342203655477580}),
+ ("student2", {"id": 422322203155477692}),
+ ],
+ schema=StructType(
+ [
+ StructField("student", StringType(), True),
+ StructField("properties", MapType(StringType(),
LongType()), True),
+ ]
+ ),
+ )
+
+ with self.assertRaises(PySparkAssertionError) as pe:
+ assertDataFrameEqual(df1, df2)
+
+ self.check_error(
+ exception=pe.exception,
+ error_class="UNSUPPORTED_DATA_TYPE_FOR_IGNORE_ROW_ORDER",
+ message_parameters={},
+ )
Review Comment:
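Sounds good! For the last case (invalid SQL query), this is the error message I
get. It looks like the exception is raised inside `self.spark.sql(...)` itself
(an `AnalysisException` during query analysis), before my
`assertDataFrameEqual` function is ever reached: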
```
======================================================================
ERROR [0.077s]: test_spark_sql_invalid
(pyspark.sql.tests.test_utils.UtilsTests.test_spark_sql_invalid)
----------------------------------------------------------------------
Traceback (most recent call last):
File
"/Users/amanda.liu/Documents/Databricks/spark/python/pyspark/sql/tests/test_utils.py",
line 734, in test_spark_sql_invalid
self.spark.sql("select name from df1"),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File
"/Users/amanda.liu/Documents/Databricks/spark/python/pyspark/sql/session.py",
line 1527, in sql
return DataFrame(self._jsparkSession.sql(sqlQuery, litArgs), self)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File
"/Users/amanda.liu/Documents/Databricks/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py",
line 1322, in __call__
return_value = get_return_value(
^^^^^^^^^^^^^^^^^
File
"/Users/amanda.liu/Documents/Databricks/spark/python/pyspark/errors/exceptions/captured.py",
line 185, in deco
raise converted from None
pyspark.errors.exceptions.captured.AnalysisException:
[UNRESOLVED_COLUMN.WITH_SUGGESTION] A column or function parameter with name
`name` cannot be resolved. Did you mean one of the following? [`id`,
`amount`].; line 1 pos 7;
'Project ['name]
+- SubqueryAlias df1
+- View (`df1`, [id#165L,amount#166L])
+- LogicalRDD [id#165L, amount#166L], false
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]