This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e9d31a0a1dd [SPARK-41875][CONNECT][PYTHON] Add test cases for `Dataset.to()`
e9d31a0a1dd is described below

commit e9d31a0a1dd54900207f92760b63bdf53f6688b4
Author: Jiaan Geng <belie...@163.com>
AuthorDate: Sat Jan 7 09:43:49 2023 +0800

    [SPARK-41875][CONNECT][PYTHON] Add test cases for `Dataset.to()`
    
    ### What changes were proposed in this pull request?
    
    1. This PR makes the parameter of Connect's `Dataset.to()` the same as PySpark's, i.e. `StructType` (a usage sketch follows the list below).
    
    2. Connect's `Dataset.to()` was missing some test cases.
    This PR adds test cases modeled on
    https://github.com/apache/spark/blob/89666d44a39c48df841a0102ff6f54eaeb4c6140/python/pyspark/sql/tests/test_dataframe.py#L1464
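
    For context, here is a minimal usage sketch of the API under test (assuming a Spark session `spark` and a hypothetical table with columns `id` and `name`, similar to the fixtures used in the diff below):

    ```python
    from pyspark.sql.types import StructType, StructField, LongType

    df = spark.read.table("test_table")  # hypothetical table with `id` and `name`

    # DataFrame.to() reconciles the DataFrame against the given StructType:
    # columns are reordered and projected away by name, and compatible types
    # are cast. Here `name` is dropped and `id` is reconciled to bigint.
    schema = StructType([StructField("id", LongType(), True)])
    df.to(schema).printSchema()
    ```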
    
    ### Why are the changes needed?
    This PR adds test coverage for Connect's `Dataset.to()`.
    
    ### Does this PR introduce _any_ user-facing change?
    No. This only adds tests and fixes a type hint.
    
    ### How was this patch tested?
    New test cases.
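
    The new negative cases assert the analysis error classes rather than full error messages; roughly (a hedged sketch with the same assumptions as above; on Spark Connect the raised exception is `SparkConnectAnalysisException`):

    ```python
    from pyspark.sql.types import StructType, StructField, LongType
    from pyspark.sql.utils import AnalysisException

    df = spark.read.table("test_table")  # hypothetical table with a nullable `id`

    # Requiring a non-nullable `id` when the source column is nullable fails at
    # analysis time; the tests match the NULLABLE_COLUMN_OR_FIELD error class.
    try:
        df.to(StructType([StructField("id", LongType(), False)])).collect()
    except AnalysisException as e:
        print(e)
    ```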
    
    Closes #39422 from beliefer/SPARK-41875.
    
    Authored-by: Jiaan Geng <belie...@163.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/sql/connect/dataframe.py            |   4 +-
 .../sql/tests/connect/test_connect_basic.py        | 116 ++++++++++++---------
 2 files changed, 69 insertions(+), 51 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index 8aca9fbb968..17b88461a43 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -40,7 +40,7 @@ from collections.abc import Iterable
 
 from pyspark import _NoValue, SparkContext, SparkConf
 from pyspark._globals import _NoValueType
-from pyspark.sql.types import DataType, StructType, Row
+from pyspark.sql.types import StructType, Row
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
@@ -1210,7 +1210,7 @@ class DataFrame:
 
     inputFiles.__doc__ = PySparkDataFrame.inputFiles.__doc__
 
-    def to(self, schema: DataType) -> "DataFrame":
+    def to(self, schema: StructType) -> "DataFrame":
         assert schema is not None
         return DataFrame.withPlan(
             plan.ToSchema(child=self._plan, schema=schema),
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 72e60712b98..31a7e6fdbad 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -30,7 +30,6 @@ from pyspark.sql.types import (
     ArrayType,
     Row,
 )
-import pyspark.sql.functions
 from pyspark.testing.utils import ReusedPySparkTestCase
 from pyspark.testing.connectutils import should_test_connect, connect_requirement_message
 from pyspark.testing.pandasutils import PandasOnSparkTestCase
@@ -43,9 +42,11 @@ if should_test_connect:
     from pyspark.sql.connect.session import SparkSession as RemoteSparkSession
     from pyspark.sql.connect.client import ChannelBuilder
     from pyspark.sql.connect.column import Column
+    from pyspark.sql.dataframe import DataFrame
     from pyspark.sql.connect.dataframe import DataFrame as CDataFrame
     from pyspark.sql.connect.function_builder import udf
-    from pyspark.sql.connect.functions import lit, col
+    from pyspark.sql import functions as SF
+    from pyspark.sql.connect import functions as CF
 
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
@@ -333,7 +334,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         """SPARK-41114: Test creating a dataframe using local data"""
         pdf = pd.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
         df = self.connect.createDataFrame(pdf)
-        rows = df.filter(df.a == lit(3)).collect()
+        rows = df.filter(df.a == CF.lit(3)).collect()
         self.assertTrue(len(rows) == 1)
         self.assertEqual(rows[0][0], 3)
         self.assertEqual(rows[0][1], "c")
@@ -679,6 +680,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
     def test_to(self):
         # SPARK-41464: test DataFrame.to()
 
+        cdf = self.connect.read.table(self.tbl_name)
+        df = self.spark.read.table(self.tbl_name)
+
+        def assert_eq_schema(cdf: CDataFrame, df: DataFrame, schema: StructType):
+            cdf_to = cdf.to(schema)
+            df_to = df.to(schema)
+            self.assertEqual(cdf_to.schema, df_to.schema)
+            self.assert_eq(cdf_to.toPandas(), df_to.toPandas())
+
         # The schema has not changed
         schema = StructType(
             [
@@ -687,11 +697,15 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             ]
         )
 
-        cdf = self.connect.read.table(self.tbl_name).to(schema)
-        df = self.spark.read.table(self.tbl_name).to(schema)
+        assert_eq_schema(cdf, df, schema)
+
+        # Change schema with struct
+        schema2 = StructType([StructField("struct", schema, False)])
+
+        cdf_to = cdf.select(CF.struct("id", "name").alias("struct")).to(schema2)
+        df_to = df.select(SF.struct("id", "name").alias("struct")).to(schema2)
 
-        self.assertEqual(cdf.schema, df.schema)
-        self.assert_eq(cdf.toPandas(), df.toPandas())
+        self.assertEqual(cdf_to.schema, df_to.schema)
 
         # Change the column name
         schema = StructType(
@@ -701,11 +715,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             ]
         )
 
-        cdf = self.connect.read.table(self.tbl_name).to(schema)
-        df = self.spark.read.table(self.tbl_name).to(schema)
-
-        self.assertEqual(cdf.schema, df.schema)
-        self.assert_eq(cdf.toPandas(), df.toPandas())
+        assert_eq_schema(cdf, df, schema)
 
         # Change the column data type
         schema = StructType(
@@ -715,26 +725,44 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             ]
         )
 
-        cdf = self.connect.read.table(self.tbl_name).to(schema)
-        df = self.spark.read.table(self.tbl_name).to(schema)
+        assert_eq_schema(cdf, df, schema)
+
+        # Reduce the column quantity and change data type
+        schema = StructType(
+            [
+                StructField("id", LongType(), True),
+            ]
+        )
+
+        assert_eq_schema(cdf, df, schema)
+
+        # incompatible field nullability
+        schema = StructType([StructField("id", LongType(), False)])
+        self.assertRaisesRegex(
+            SparkConnectAnalysisException,
+            "NULLABLE_COLUMN_OR_FIELD",
+            lambda: cdf.to(schema).toPandas(),
+        )
 
-        self.assertEqual(cdf.schema, df.schema)
-        self.assert_eq(cdf.toPandas(), df.toPandas())
+        # field cannot upcast
+        schema = StructType([StructField("name", LongType())])
+        self.assertRaisesRegex(
+            SparkConnectAnalysisException,
+            "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
+            lambda: cdf.to(schema).toPandas(),
+        )
 
-        # Change the column data type failed
         schema = StructType(
             [
                 StructField("id", IntegerType(), True),
                 StructField("name", IntegerType(), True),
             ]
         )
-
-        with self.assertRaises(SparkConnectException) as context:
-            self.connect.read.table(self.tbl_name).to(schema).toPandas()
-            self.assertIn(
-                """Column or field `name` is of type "STRING" while it's required to be "INT".""",
-                str(context.exception),
-            )
+        self.assertRaisesRegex(
+            SparkConnectAnalysisException,
+            "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
+            lambda: cdf.to(schema).toPandas(),
+        )
 
         # Test map type and array type
         schema = StructType(
@@ -744,11 +772,10 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
                 StructField("my_array", ArrayType(IntegerType(), False), True),
             ]
         )
-        cdf = self.connect.read.table(self.tbl_name4).to(schema)
-        df = self.spark.read.table(self.tbl_name4).to(schema)
+        cdf = self.connect.read.table(self.tbl_name4)
+        df = self.spark.read.table(self.tbl_name4)
 
-        self.assertEqual(cdf.schema, df.schema)
-        self.assert_eq(cdf.toPandas(), df.toPandas())
+        assert_eq_schema(cdf, df, schema)
 
     def test_toDF(self):
         # SPARK-41310: test DataFrame.toDF()
@@ -1195,21 +1222,19 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
     def test_with_columns(self):
         # SPARK-41256: test withColumn(s).
         self.assert_eq(
-            self.connect.read.table(self.tbl_name).withColumn("id", lit(False)).toPandas(),
-            self.spark.read.table(self.tbl_name)
-            .withColumn("id", pyspark.sql.functions.lit(False))
-            .toPandas(),
+            self.connect.read.table(self.tbl_name).withColumn("id", CF.lit(False)).toPandas(),
+            self.spark.read.table(self.tbl_name).withColumn("id", SF.lit(False)).toPandas(),
         )
 
         self.assert_eq(
             self.connect.read.table(self.tbl_name)
-            .withColumns({"id": lit(False), "col_not_exist": lit(False)})
+            .withColumns({"id": CF.lit(False), "col_not_exist": CF.lit(False)})
             .toPandas(),
             self.spark.read.table(self.tbl_name)
             .withColumns(
                 {
-                    "id": pyspark.sql.functions.lit(False),
-                    "col_not_exist": pyspark.sql.functions.lit(False),
+                    "id": SF.lit(False),
+                    "col_not_exist": SF.lit(False),
                 }
             )
             .toPandas(),
@@ -1392,9 +1417,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
     def test_stat_sample_by(self):
         # SPARK-41069: Test stat.sample_by
 
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
-
         cdf = self.connect.range(0, 100).select((CF.col("id") % 3).alias("key"))
         sdf = self.spark.range(0, 100).select((SF.col("id") % 3).alias("key"))
 
@@ -1475,7 +1497,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         """SPARK-41203: Support DF.transform"""
 
         def transform_df(input_df: CDataFrame) -> CDataFrame:
-            return input_df.select((col("id") + lit(10)).alias("id"))
+            return input_df.select((CF.col("id") + CF.lit(10)).alias("id"))
 
         df = self.connect.range(1, 100)
         result_left = df.transform(transform_df).collect()
@@ -1490,13 +1512,13 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         """Testing supported and unsupported alias"""
         col0 = (
             self.connect.range(1, 10)
-            .select(col("id").alias("name", metadata={"max": 99}))
+            .select(CF.col("id").alias("name", metadata={"max": 99}))
             .schema.names[0]
         )
         self.assertEqual("name", col0)
 
         with self.assertRaises(SparkConnectException) as exc:
-            self.connect.range(1, 10).select(col("id").alias("this", "is", "not")).collect()
+            self.connect.range(1, 10).select(CF.col("id").alias("this", "is", "not")).collect()
         self.assertIn("(this, is, not)", str(exc.exception))
 
     def test_column_regexp(self) -> None:
@@ -1555,7 +1577,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         # SPARK-41325: groupby.avg()
         df = (
             self.connect.range(10)
-            .groupBy((col("id") % lit(2)).alias("moded"))
+            .groupBy((CF.col("id") % CF.lit(2)).alias("moded"))
             .avg("id")
             .sort("moded")
         )
@@ -1565,13 +1587,11 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         self.assertEqual(5.0, res[1][1])
 
         # Additional GroupBy tests with 3 rows
-        import pyspark.sql.connect.functions as CF
-        import pyspark.sql.functions as PF
 
-        df_a = self.connect.range(10).groupBy((col("id") % lit(3)).alias("moded"))
-        df_b = self.spark.range(10).groupBy((PF.col("id") % PF.lit(3)).alias("moded"))
+        df_a = self.connect.range(10).groupBy((CF.col("id") % CF.lit(3)).alias("moded"))
+        df_b = self.spark.range(10).groupBy((SF.col("id") % SF.lit(3)).alias("moded"))
         self.assertEqual(
-            set(df_b.agg(PF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect())
+            set(df_b.agg(SF.sum("id")).collect()), set(df_a.agg(CF.sum("id")).collect())
         )
 
         # Dict agg
@@ -1603,8 +1623,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
         )
 
     def test_grouped_data(self):
-        from pyspark.sql import functions as SF
-        from pyspark.sql.connect import functions as CF
 
         query = """
             SELECT * FROM VALUES

