This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 91b95056806 [SPARK-39823][SQL][PYTHON] Rename Dataset.as as Dataset.to and add DataFrame.to in PySpark
91b95056806 is described below

commit 91b950568066830ecd7a4581ab5bf4dbdbbeb474
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Wed Jul 27 08:11:18 2022 +0800

    [SPARK-39823][SQL][PYTHON] Rename Dataset.as as Dataset.to and add DataFrame.to in PySpark
    
    ### What changes were proposed in this pull request?
    
    1. Rename `Dataset.as(StructType)` to `Dataset.to(StructType)`, since `as` is a keyword in Python and we don't want the Python API to use a different name;
    2. Add `DataFrame.to(StructType)` in PySpark.
    
    ### Why are the changes needed?
    For function parity between the Scala and Python APIs.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes, this adds a new API.
    
    ### How was this patch tested?
    Added unit tests.
    
    Closes #37233 from zhengruifeng/py_ds_as.
    
    Authored-by: Ruifeng Zheng <ruife...@apache.org>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../source/reference/pyspark.sql/dataframe.rst     |  1 +
 python/pyspark/sql/dataframe.py                    | 50 +++++++++++++++++++++
 python/pyspark/sql/tests/test_dataframe.py         | 36 ++++++++++++++-
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  4 +-
 ...emaSuite.scala => DataFrameToSchemaSuite.scala} | 52 +++++++++++-----------
 5 files changed, 114 insertions(+), 29 deletions(-)
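
As a quick illustration of the renamed API from the Python side (a minimal sketch mirroring the docstring example in the diff below; it assumes a SparkSession on a build that includes this commit, i.e. 3.4.0 or later):

    from pyspark.sql import SparkSession
    from pyspark.sql.types import StructType, StructField, StringType

    spark = SparkSession.builder.getOrCreate()

    # i is inferred as string, j as bigint
    df = spark.createDataFrame([("a", 1)], ["i", "j"])

    # reorder columns by name and cast j from bigint to string
    schema = StructType([StructField("j", StringType()), StructField("i", StringType())])
    df.to(schema).show()
    # +---+---+
    # |  j|  i|
    # +---+---+
    # |  1|  a|
    # +---+---+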

diff --git a/python/docs/source/reference/pyspark.sql/dataframe.rst b/python/docs/source/reference/pyspark.sql/dataframe.rst
index 5b6e704ba48..8cf083e5dd4 100644
--- a/python/docs/source/reference/pyspark.sql/dataframe.rst
+++ b/python/docs/source/reference/pyspark.sql/dataframe.rst
@@ -102,6 +102,7 @@ DataFrame
     DataFrame.summary
     DataFrame.tail
     DataFrame.take
+    DataFrame.to
     DataFrame.toDF
     DataFrame.toJSON
     DataFrame.toLocalIterator
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index efebd05c08d..481dafa310d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1422,6 +1422,56 @@ class DataFrame(PandasMapOpsMixin, PandasConversionMixin):
         jc = self._jdf.colRegex(colName)
         return Column(jc)
 
+    def to(self, schema: StructType) -> "DataFrame":
+        """
+        Returns a new :class:`DataFrame` where each row is reconciled to match the specified
+        schema.
+
+        Notes
+        -----
+        1, Reorder columns and/or inner fields by name to match the specified schema.
+
+        2, Project away columns and/or inner fields that are not needed by the specified schema.
+        Missing columns and/or inner fields (present in the specified schema but not input
+        DataFrame) lead to failures.
+
+        3, Cast the columns and/or inner fields to match the data types in the specified schema,
+        if the types are compatible, e.g., numeric to numeric (error if overflows), but not string
+        to int.
+
+        4, Carry over the metadata from the specified schema, while the columns and/or inner fields
+        still keep their own metadata if not overwritten by the specified schema.
+
+        5, Fail if the nullability is not compatible. For example, the column and/or inner field
+        is nullable but the specified schema requires them to be not nullable.
+
+        .. versionadded:: 3.4.0
+
+        Parameters
+        ----------
+        schema : :class:`StructType`
+            Specified schema.
+
+        Examples
+        --------
+        >>> df = spark.createDataFrame([("a", 1)], ["i", "j"])
+        >>> df.schema
+        StructType([StructField('i', StringType(), True), StructField('j', LongType(), True)])
+        >>> schema = StructType([StructField("j", StringType()), StructField("i", StringType())])
+        >>> df2 = df.to(schema)
+        >>> df2.schema
+        StructType([StructField('j', StringType(), True), StructField('i', StringType(), True)])
+        >>> df2.show()
+        +---+---+
+        |  j|  i|
+        +---+---+
+        |  1|  a|
+        +---+---+
+        """
+        assert schema is not None
+        jschema = self._jdf.sparkSession().parseDataType(schema.json())
+        return DataFrame(self._jdf.to(jschema), self.sparkSession)
+
     def alias(self, alias: str) -> "DataFrame":
         """Returns a new :class:`DataFrame` with an alias set.
 
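Beyond the column reordering shown in the docstring example, a short sketch of the projection and nullability rules from the notes above, modeled on the new test cases that follow (it reuses the `spark` session from the earlier sketch; the AnalysisException import path is the one in use at the time of this commit):

    from pyspark.sql.types import StructType, StructField, LongType
    from pyspark.sql.utils import AnalysisException

    df = spark.createDataFrame([("a", 1)], ["i", "j"])  # i: string, j: bigint

    # rule 2: columns absent from the specified schema are projected away
    df.to(StructType([StructField("j", LongType())])).printSchema()
    # root
    #  |-- j: long (nullable = true)

    # rule 5: requiring a nullable column to be non-nullable fails analysis
    try:
        df.to(StructType([StructField("j", LongType(), False)]))
    except AnalysisException as e:
        print("NULLABLE_COLUMN_OR_FIELD" in str(e))  # True
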
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index ac6b6f68aed..7c7d3d1e51c 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -25,11 +25,12 @@ import unittest
 from typing import cast
 
 from pyspark.sql import SparkSession, Row
-from pyspark.sql.functions import col, lit, count, sum, mean
+from pyspark.sql.functions import col, lit, count, sum, mean, struct
 from pyspark.sql.types import (
     StringType,
     IntegerType,
     DoubleType,
+    LongType,
     StructType,
     StructField,
     BooleanType,
@@ -1200,6 +1201,39 @@ class DataFrameTests(ReusedSQLTestCase):
             [Row(value=None)],
         )
 
+    def test_to(self):
+        schema = StructType(
+            [StructField("i", StringType(), True), StructField("j", IntegerType(), True)]
+        )
+        df = self.spark.createDataFrame([("a", 1)], schema)
+
+        schema1 = StructType([StructField("j", StringType()), StructField("i", StringType())])
+        df1 = df.to(schema1)
+        self.assertEqual(schema1, df1.schema)
+        self.assertEqual(df.count(), df1.count())
+
+        schema2 = StructType([StructField("j", LongType())])
+        df2 = df.to(schema2)
+        self.assertEqual(schema2, df2.schema)
+        self.assertEqual(df.count(), df2.count())
+
+        schema3 = StructType([StructField("struct", schema1, False)])
+        df3 = df.select(struct("i", "j").alias("struct")).to(schema3)
+        self.assertEqual(schema3, df3.schema)
+        self.assertEqual(df.count(), df3.count())
+
+        # incompatible field nullability
+        schema4 = StructType([StructField("j", LongType(), False)])
+        self.assertRaisesRegex(
+            AnalysisException, "NULLABLE_COLUMN_OR_FIELD", lambda: df.to(schema4)
+        )
+
+        # field cannot upcast
+        schema5 = StructType([StructField("i", LongType())])
+        self.assertRaisesRegex(
+            AnalysisException, "INVALID_COLUMN_OR_FIELD_DATA_TYPE", lambda: df.to(schema5)
+        )
+
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 49b4a8389f9..2e1dc7d83d2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -476,14 +476,14 @@ class Dataset[T] private[sql](
    *   int.</li>
   *   <li>Carry over the metadata from the specified schema, while the columns and/or inner fields
   *   still keep their own metadata if not overwritten by the specified schema.</li>
-   *   <li>Fail if the nullability are not compatible. For example, the column and/or inner field is
+   *   <li>Fail if the nullability is not compatible. For example, the column and/or inner field is
    *   nullable but the specified schema requires them to be not nullable.</li>
    * </ul>
    *
    * @group basic
    * @since 3.4.0
    */
-  def as(schema: StructType): DataFrame = withPlan {
+  def to(schema: StructType): DataFrame = withPlan {
     Project.matchSchema(logicalPlan, schema, sparkSession.sessionState.conf)
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala
similarity index 93%
rename from sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsSchemaSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala
index eccbfc339f0..26ddbc4569e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAsSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameToSchemaSuite.scala
@@ -22,33 +22,33 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
 
-class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
+class DataFrameToSchemaSuite extends QueryTest with SharedSparkSession {
   import testImplicits._
 
   test("reorder columns by name") {
     val schema = new StructType().add("j", StringType).add("i", StringType)
-    val df = Seq("a" -> "b").toDF("i", "j").as(schema)
+    val df = Seq("a" -> "b").toDF("i", "j").to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row("b", "a"))
   }
 
   test("case insensitive: reorder columns by name") {
     val schema = new StructType().add("J", StringType).add("I", StringType)
-    val df = Seq("a" -> "b").toDF("i", "j").as(schema)
+    val df = Seq("a" -> "b").toDF("i", "j").to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row("b", "a"))
   }
 
   test("select part of the columns") {
     val schema = new StructType().add("j", StringType)
-    val df = Seq("a" -> "b").toDF("i", "j").as(schema)
+    val df = Seq("a" -> "b").toDF("i", "j").to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row("b"))
   }
 
   test("negative: column not found") {
     val schema = new StructType().add("non_exist", StringType)
-    val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "j").as(schema))
+    val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "j").to(schema))
     checkError(
       exception = e,
       errorClass = "UNRESOLVED_COLUMN",
@@ -59,7 +59,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
 
   test("negative: ambiguous column") {
     val schema = new StructType().add("i", StringType)
-    val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "I").as(schema))
+    val e = intercept[SparkThrowable](Seq("a" -> "b").toDF("i", "I").to(schema))
     checkError(
       exception = e,
       errorClass = "AMBIGUOUS_COLUMN_OR_FIELD",
@@ -72,7 +72,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val schema = new StructType().add("j", IntegerType)
     val data = Seq("a" -> 1).toDF("i", "j")
     assert(!data.schema.fields(1).nullable)
-    val df = data.as(schema)
+    val df = data.to(schema)
     val finalSchema = new StructType().add("j", IntegerType, nullable = false)
     assert(df.schema == finalSchema)
     checkAnswer(df, Row(1))
@@ -82,7 +82,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val schema = new StructType().add("i", IntegerType, nullable = false)
     val data = sql("SELECT i FROM VALUES 1, NULL as t(i)")
     assert(data.schema.fields(0).nullable)
-    val e = intercept[SparkThrowable](data.as(schema))
+    val e = intercept[SparkThrowable](data.to(schema))
     checkError(
       exception = e,
       errorClass = "NULLABLE_COLUMN_OR_FIELD",
@@ -91,14 +91,14 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
 
   test("upcast the original column") {
     val schema = new StructType().add("j", LongType, nullable = false)
-    val df = Seq("a" -> 1).toDF("i", "j").as(schema)
+    val df = Seq("a" -> 1).toDF("i", "j").to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(1L))
   }
 
   test("negative: column cannot upcast") {
     val schema = new StructType().add("i", IntegerType)
-    val e = intercept[SparkThrowable](Seq("a" -> 1).toDF("i", "j").as(schema))
+    val e = intercept[SparkThrowable](Seq("a" -> 1).toDF("i", "j").to(schema))
     checkError(
       exception = e,
       errorClass = "INVALID_COLUMN_OR_FIELD_DATA_TYPE",
@@ -113,7 +113,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val metadata1 = new MetadataBuilder().putString("a", "1").putString("b", "2").build()
     val metadata2 = new MetadataBuilder().putString("b", "3").putString("c", "4").build()
     val schema = new StructType().add("i", IntegerType, nullable = true, metadata = metadata2)
-    val df = Seq((1)).toDF("i").select($"i".as("i", metadata1)).as(schema)
+    val df = Seq((1)).toDF("i").select($"i".as("i", metadata1)).to(schema)
     // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added.
     val resultMetadata = new MetadataBuilder()
       .putString("a", "1").putString("b", "3").putString("c", "4").build()
@@ -124,7 +124,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
   test("reorder inner fields by name") {
     val innerFields = new StructType().add("j", StringType).add("i", StringType)
     val schema = new StructType().add("struct", innerFields, nullable = false)
-    val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema)
+    val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(Row("b", "a")))
   }
@@ -132,7 +132,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
   test("case insensitive: reorder inner fields by name") {
     val innerFields = new StructType().add("J", StringType).add("I", StringType)
     val schema = new StructType().add("struct", innerFields, nullable = false)
-    val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema)
+    val df = Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(Row("b", "a")))
   }
@@ -141,7 +141,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val innerFields = new StructType().add("non_exist", StringType)
     val schema = new StructType().add("struct", innerFields, nullable = false)
     val e = intercept[SparkThrowable] {
-      Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema)
+      Seq("a" -> "b").toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema)
     }
     checkError(
       exception = e,
@@ -158,7 +158,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val data = Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct"))
     assert(!data.schema.fields(0).nullable)
     assert(!data.schema.fields(0).dataType.asInstanceOf[StructType].fields(1).nullable)
-    val df = data.as(schema)
+    val df = data.to(schema)
     val finalFields = new StructType().add("j", IntegerType, nullable = false)
     val finalSchema = new StructType().add("struct", finalFields, nullable = false)
     assert(df.schema == finalSchema)
@@ -171,7 +171,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val data = sql("SELECT i FROM VALUES 1, NULL as t(i)").select(struct($"i").as("struct"))
     assert(!data.schema.fields(0).nullable)
     assert(data.schema.fields(0).dataType.asInstanceOf[StructType].fields(0).nullable)
-    val e = intercept[SparkThrowable](data.as(schema))
+    val e = intercept[SparkThrowable](data.to(schema))
     checkError(
       exception = e,
       errorClass = "NULLABLE_COLUMN_OR_FIELD",
@@ -181,7 +181,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
   test("upcast the original field") {
     val innerFields = new StructType().add("j", LongType, nullable = false)
     val schema = new StructType().add("struct", innerFields, nullable = false)
-    val df = Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema)
+    val df = Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(Row(1L)))
   }
@@ -190,7 +190,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val innerFields = new StructType().add("i", IntegerType)
     val schema = new StructType().add("struct", innerFields, nullable = false)
     val e = intercept[SparkThrowable] {
-      Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).as(schema)
+      Seq("a" -> 1).toDF("i", "j").select(struct($"i", $"j").as("struct")).to(schema)
     }
     checkError(
       exception = e,
@@ -210,7 +210,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val df = Seq((1)).toDF("i")
       .select($"i".as("i", metadata1))
       .select(struct($"i").as("struct"))
-      .as(schema)
+      .to(schema)
     // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added.
     val resultMetadata = new MetadataBuilder()
       .putString("a", "1").putString("b", "3").putString("c", "4").build()
@@ -223,7 +223,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val schema = new StructType().add("arr", arr, nullable = false)
     val df = Seq("a" -> "b").toDF("i", "j")
       .select(array(struct($"i", $"j")).as("arr"))
-      .as(schema)
+      .to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(Seq(Row("b", "a"))))
   }
@@ -234,7 +234,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val schema = new StructType().add("arr", arr, nullable = false)
     val df = Seq("a" -> 1).toDF("i", "j")
       .select(array(struct($"i", $"j")).as("arr"))
-      .as(schema)
+      .to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(Seq(Row(1L))))
   }
@@ -244,7 +244,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val schema = new StructType().add("arr", arr)
     val data = sql("SELECT i FROM VALUES 1, NULL as t(i)").select(array($"i").as("arr"))
     assert(data.schema.fields(0).dataType.asInstanceOf[ArrayType].containsNull)
-    val e = intercept[SparkThrowable](data.as(schema))
+    val e = intercept[SparkThrowable](data.to(schema))
     checkError(
       exception = e,
       errorClass = "NULLABLE_ARRAY_OR_MAP_ELEMENT",
@@ -260,7 +260,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val df = Seq((1)).toDF("i")
       .select($"i")
       .select(array(struct($"i")).as("arr", metadata1))
-      .as(schema)
+      .to(schema)
     // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added.
     val resultMetadata = new MetadataBuilder()
       .putString("a", "1").putString("b", "3").putString("c", "4").build()
@@ -276,7 +276,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val df = Seq((1)).toDF("i")
       .select($"i".as("i", metadata1))
       .select(array(struct($"i")).as("arr"))
-      .as(schema)
+      .to(schema)
     // Metadata "a" remains, "b" gets overwritten by the specified schema, "c" is newly added.
     val resultMetadata = new MetadataBuilder()
       .putString("a", "1").putString("b", "3").putString("c", "4").build()
@@ -290,7 +290,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val schema = new StructType().add("map", m, nullable = false)
     val df = Seq("a" -> "b").toDF("i", "j")
       .select(map(struct($"i", $"j"), $"i").as("map"))
-      .as(schema)
+      .to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(Map(Row("b", "a") -> "a")))
   }
@@ -301,7 +301,7 @@ class DataFrameAsSchemaSuite extends QueryTest with SharedSparkSession {
     val schema = new StructType().add("map", m, nullable = false)
     val df = Seq("a" -> "b").toDF("i", "j")
       .select(map($"i", struct($"i", $"j")).as("map"))
-      .as(schema)
+      .to(schema)
     assert(df.schema == schema)
     checkAnswer(df, Row(Map("a" -> Row("b", "a"))))
   }

