This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e79e8797e44 [SPARK-44827][PYTHON][TESTS] Fix test when ansi mode
enabled
e79e8797e44 is described below
commit e79e8797e4467c85e5ff5ad5b49631a3177a461b
Author: panbingkun <[email protected]>
AuthorDate: Fri Aug 25 09:09:26 2023 +0900
[SPARK-44827][PYTHON][TESTS] Fix test when ansi mode enabled
### What changes were proposed in this pull request?
This PR aims to fix several unit tests that fail when SPARK_ANSI_SQL_MODE=true, including:
- test_assert_approx_equal_decimaltype_custom_rtol_pass
- functions.to_unix_timestamp
- DataFrame.union
### Why are the changes needed?
Make the PySpark tests pass.
When the daily ANSI-mode GitHub Actions workflow runs, errors such as the following occur, e.g.:
https://github.com/apache/spark/actions/runs/5873530086/job/15926967006

### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
- Pass GA.
- Manually test:
```
(base) panbingkun:~/Developer/spark/spark-community$export
SPARK_ANSI_SQL_MODE=true
(base) panbingkun:~/Developer/spark/spark-community$python/run-tests
--testnames 'pyspark.sql.tests.connect.test_utils
ConnectUtilsTests.test_assert_approx_equal_decimaltype_custom_rtol_pass'
Running PySpark tests. Output is in
/Users/panbingkun/Developer/spark/spark-community/python/unit-tests.log
Will test against the following Python executables: ['python3.9']
Will test the following Python tests:
['pyspark.sql.tests.connect.test_utils
ConnectUtilsTests.test_assert_approx_equal_decimaltype_custom_rtol_pass']
python3.9 python_implementation is CPython
python3.9 version is: Python 3.9.13
Starting test(python3.9): pyspark.sql.tests.connect.test_utils
ConnectUtilsTests.test_assert_approx_equal_decimaltype_custom_rtol_pass (temp
output:
/Users/panbingkun/Developer/spark/spark-community/python/target/b59e563b-ac28-4279-ae95-462cde8f19c3/python3.9__pyspark.sql.tests.connect.test_utils_ConnectUtilsTests.test_assert_approx_equal_decimaltype_custom_rtol_pass__9ypt4lse.log)
Finished test(python3.9): pyspark.sql.tests.connect.test_utils
ConnectUtilsTests.test_assert_approx_equal_decimaltype_custom_rtol_pass (8s)
Tests passed in 8 seconds
```
Closes #42513 from panbingkun/SPARK-44827.
Lead-authored-by: panbingkun <[email protected]>
Co-authored-by: panbingkun <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/dataframe.py | 20 ++++++++++----------
python/pyspark/sql/functions.py | 6 ------
python/pyspark/sql/tests/test_utils.py | 4 ++--
python/pyspark/testing/utils.py | 8 ++++++--
4 files changed, 18 insertions(+), 20 deletions(-)
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 2ca5be76e3b..8eaac594b40 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -3956,20 +3956,20 @@ class DataFrame(PandasMapOpsMixin,
PandasConversionMixin):
Example 2: Combining two DataFrames with different schemas
>>> from pyspark.sql.functions import lit
- >>> df1 = spark.createDataFrame([("Alice", 1), ("Bob", 2)], ["name",
"id"])
- >>> df2 = spark.createDataFrame([(3, "Charlie"), (4, "Dave")], ["id",
"name"])
+ >>> df1 = spark.createDataFrame([(100001, 1), (100002, 2)], schema="id
LONG, money INT")
+ >>> df2 = spark.createDataFrame([(3, 100003), (4, 100003)],
schema="money INT, id LONG")
>>> df1 = df1.withColumn("age", lit(30))
>>> df2 = df2.withColumn("age", lit(40))
>>> df3 = df1.union(df2)
>>> df3.show()
- +-----+-------+---+
- | name| id|age|
- +-----+-------+---+
- |Alice| 1| 30|
- | Bob| 2| 30|
- | 3|Charlie| 40|
- | 4| Dave| 40|
- +-----+-------+---+
+ +------+------+---+
+ | id| money|age|
+ +------+------+---+
+ |100001| 1| 30|
+ |100002| 2| 30|
+ | 3|100003| 40|
+ | 4|100003| 40|
+ +------+------+---+
Example 3: Combining two DataFrames with mismatched columns
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e580d2aba12..535ac06530a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -7843,12 +7843,6 @@ def to_unix_timestamp(
>>> df.select(to_unix_timestamp(df.e,
lit("yyyy-MM-dd")).alias('r')).collect()
[Row(r=1460098800)]
>>> spark.conf.unset("spark.sql.session.timeZone")
-
- >>> spark.conf.set("spark.sql.session.timeZone", "America/Los_Angeles")
- >>> df = spark.createDataFrame([("2016-04-08",)], ["e"])
- >>> df.select(to_unix_timestamp(df.e).alias('r')).collect()
- [Row(r=None)]
- >>> spark.conf.unset("spark.sql.session.timeZone")
"""
if format is not None:
return _invoke_function_over_columns("to_unix_timestamp", timestamp,
format)
diff --git a/python/pyspark/sql/tests/test_utils.py
b/python/pyspark/sql/tests/test_utils.py
index e1b7f298d0a..a2cad4e83bd 100644
--- a/python/pyspark/sql/tests/test_utils.py
+++ b/python/pyspark/sql/tests/test_utils.py
@@ -273,8 +273,8 @@ class UtilsTestsMixin:
)
# cast to DecimalType
- df1 = df1.withColumn("col_1", F.col("grade").cast("decimal(4,3)"))
- df2 = df2.withColumn("col_1", F.col("grade").cast("decimal(4,3)"))
+ df1 = df1.withColumn("col_1", F.col("grade").cast("decimal(5,3)"))
+ df2 = df2.withColumn("col_1", F.col("grade").cast("decimal(5,3)"))
assertDataFrameEqual(df1, df2, rtol=1e-1)
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 7dd723634e2..2a508b8a450 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -21,6 +21,7 @@ import struct
import sys
import unittest
import difflib
+from decimal import Decimal
from time import time, sleep
from typing import (
Any,
@@ -411,8 +412,8 @@ def assertDataFrameEqual(
Note that schema equality is checked only when `expected` is a DataFrame
(not a list of Rows).
- For DataFrames with float values, assertDataFrame asserts approximate
equality.
- Two float values a and b are approximately equal if the following equation
is True:
+ For DataFrames with float/decimal values, assertDataFrame asserts
approximate equality.
+ Two float/decimal values a and b are approximately equal if the following
equation is True:
``absolute(a - b) <= (atol + rtol * absolute(b))``.
@@ -539,6 +540,9 @@ def assertDataFrameEqual(
elif isinstance(val1, float) and isinstance(val2, float):
if abs(val1 - val2) > (atol + rtol * abs(val2)):
return False
+ elif isinstance(val1, Decimal) and isinstance(val2, Decimal):
+ if abs(val1 - val2) > (Decimal(atol) + Decimal(rtol) *
abs(val2)):
+ return False
else:
if val1 != val2:
return False
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]