This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 78eda439c99 [SPARK-41783][SPARK-41770][CONNECT][PYTHON] Make column op support None
78eda439c99 is described below

commit 78eda439c999c59d9cc91fa541366c055094cd9e
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Dec 31 09:51:38 2022 +0900

    [SPARK-41783][SPARK-41770][CONNECT][PYTHON] Make column op support None
    
    ### What changes were proposed in this pull request?
    Make column op support None
    
    ### Why are the changes needed?
    to be consistent with PySpark
    
    ### Does this PR introduce _any_ user-facing change?
    yes
    
    ### How was this patch tested?
    added UT
    
    Closes #39302 from zhengruifeng/connect_column_none.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 python/pyspark/sql/connect/column.py               |  7 ++--
 .../sql/tests/connect/test_connect_column.py       | 44 ++++++++++++++++++++--
 2 files changed, 44 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index 9be202145f2..e694b71754e 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -72,7 +72,7 @@ def _bin_op(
     def wrapped(self: "Column", other: Any) -> "Column":
         from pyspark.sql.connect.functions import lit
 
-        if isinstance(
+        if other is None or isinstance(
             other, (bool, float, int, str, datetime.datetime, datetime.date, decimal.Decimal)
         ):
             other = lit(other)
@@ -255,7 +255,7 @@ class Column:
         """
         from pyspark.sql.connect.functions import lit
 
-        if isinstance(
+        if other is None or isinstance(
             other, (bool, float, int, str, datetime.datetime, datetime.date, decimal.Decimal)
         ):
             other = lit(other)
@@ -452,7 +452,8 @@ def _test() -> None:
         del pyspark.sql.connect.column.Column.bitwiseAND.__doc__
         del pyspark.sql.connect.column.Column.bitwiseOR.__doc__
         del pyspark.sql.connect.column.Column.bitwiseXOR.__doc__
-        # TODO(SPARK-41770): eqNullSafe does not support None as its argument
+        # TODO(SPARK-41745): SparkSession.createDataFrame does not respect the column names in
+        #  the row
         del pyspark.sql.connect.column.Column.eqNullSafe.__doc__
         # TODO(SPARK-41745): SparkSession.createDataFrame does not respect the column names in
         #  the row
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index 9d18a1fe9b2..f7ce6de8922 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -228,8 +228,8 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         # |               null|2022-12-22|null|
         # +-------------------+----------+----+
 
-        cdf = self.spark.sql(query)
-        sdf = self.connect.sql(query)
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
 
         # datetime.date
         self.assert_eq(
@@ -286,8 +286,8 @@ class SparkConnectTests(SparkConnectSQLTestCase):
         # |  3|   3|  4| 3.5|
         # +---+----+---+----+
 
-        cdf = self.spark.sql(query)
-        sdf = self.connect.sql(query)
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
 
         self.assert_eq(
             cdf.select(cdf.a < decimal.Decimal(3)).toPandas(),
@@ -310,6 +310,42 @@ class SparkConnectTests(SparkConnectSQLTestCase):
             sdf.select(sdf.d >= decimal.Decimal(3.0)).toPandas(),
         )
 
+    def test_none(self):
+        # SPARK-41783: test none
+
+        from pyspark.sql import functions as SF
+        from pyspark.sql.connect import functions as CF
+
+        query = """
+            SELECT * FROM VALUES
+            (1, 1, NULL), (2, NULL, 1), (NULL, 3, 4)
+            AS tab(a, b, c)
+            """
+
+        # +----+----+----+
+        # |   a|   b|   c|
+        # +----+----+----+
+        # |   1|   1|null|
+        # |   2|null|   1|
+        # |null|   3|   4|
+        # +----+----+----+
+
+        cdf = self.connect.sql(query)
+        sdf = self.spark.sql(query)
+
+        self.assert_eq(
+            cdf.select(cdf.b > None, CF.col("c") >= None).toPandas(),
+            sdf.select(sdf.b > None, SF.col("c") >= None).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.b < None, CF.col("c") <= None).toPandas(),
+            sdf.select(sdf.b < None, SF.col("c") <= None).toPandas(),
+        )
+        self.assert_eq(
+            cdf.select(cdf.b.eqNullSafe(None), CF.col("c").eqNullSafe(None)).toPandas(),
+            sdf.select(sdf.b.eqNullSafe(None), SF.col("c").eqNullSafe(None)).toPandas(),
+        )
+
     def test_simple_binary_expressions(self):
         """Test complex expression"""
         df = self.connect.read.table(self.tbl_name)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to