This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0469479c751 [SPARK-41818][SPARK-42000][CONNECT] Fix saveAsTable to find the default source
0469479c751 is described below

commit 0469479c75174acf9873c819627b74c8286cee6a
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Tue Feb 14 09:05:57 2023 +0900

    [SPARK-41818][SPARK-42000][CONNECT] Fix saveAsTable to find the default source
    
    ### What changes were proposed in this pull request?
    
    Fixes `DataFrameWriter.saveAsTable` in Spark Connect so that it falls back
    to the default data source when no format is specified.
    
    ### Why are the changes needed?
    
    Currently `DataFrameWriter.saveAsTable` fails when `format` is not
    specified. Because the protobuf message declares `source` as a required
    (non-optional) string field, an unset value arrives on the server as an
    empty string rather than `null`. As a result, `DataFrameWriter` tries to
    look up a data source named `""`.
    
    The `source` field should be declared `optional`. The server can then tell
    "not set" apart from an empty string and let Spark choose its default
    data source.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Users can now call `saveAsTable` without specifying `format`.
    
    ### How was this patch tested?
    
    Re-enabled the related tests: the `saveAsTable` doctest for Spark Connect
    and the `test_insert_into` parity test.
    
    Closes #40000 from ueshin/issues/SPARK-42000/saveAsTable.
    
    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit c1f242eb4c514d7fba8e0c47c96e40cba82a39ad)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../src/main/protobuf/spark/connect/commands.proto |  4 +--
 .../sql/connect/planner/SparkConnectPlanner.scala  |  2 +-
 python/pyspark/sql/connect/proto/commands_pb2.py   | 32 +++++++++++-----------
 python/pyspark/sql/connect/proto/commands_pb2.pyi  | 15 ++++++++--
 python/pyspark/sql/connect/readwriter.py           |  5 ++--
 python/pyspark/sql/readwriter.py                   | 18 ++++++------
 .../sql/tests/connect/test_parity_readwriter.py    |  5 ----
 7 files changed, 44 insertions(+), 37 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 73218697577..8872dc626a9 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -66,8 +66,8 @@ message WriteOperation {
   // (Required) The output of the `input` relation will be persisted according 
to the options.
   Relation input = 1;
 
-  // (Required) Format value according to the Spark documentation. Examples 
are: text, parquet, delta.
-  string source = 2;
+  // (Optional) Format value according to the Spark documentation. Examples 
are: text, parquet, delta.
+  optional string source = 2;
 
   // The destination of the write operation must be either a path or a table.
   oneof save_type {
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 53d494cdcb7..d509a926cdd 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1539,7 +1539,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       w.partitionBy(names.toSeq: _*)
     }
 
-    if (writeOperation.getSource != null) {
+    if (writeOperation.hasSource) {
       w.format(writeOperation.getSource)
     }
 
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py 
b/python/pyspark/sql/connect/proto/commands_pb2.py
index f7e9260212e..a4b7fe268ce 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
 
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
 
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1
 [...]
+    
b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01
 
\x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02
 
\x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1
 [...]
 )
 
 
@@ -151,19 +151,19 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
     _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
     _WRITEOPERATION._serialized_start = 749
-    _WRITEOPERATION._serialized_end = 1491
-    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
-    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
-    _WRITEOPERATION_BUCKETBY._serialized_start = 1247
-    _WRITEOPERATION_BUCKETBY._serialized_end = 1338
-    _WRITEOPERATION_SAVEMODE._serialized_start = 1341
-    _WRITEOPERATION_SAVEMODE._serialized_end = 1478
-    _WRITEOPERATIONV2._serialized_start = 1494
-    _WRITEOPERATIONV2._serialized_end = 2289
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
-    _WRITEOPERATIONV2_MODE._serialized_start = 2130
-    _WRITEOPERATIONV2_MODE._serialized_end = 2289
+    _WRITEOPERATION._serialized_end = 1507
+    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1192
+    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1250
+    _WRITEOPERATION_BUCKETBY._serialized_start = 1252
+    _WRITEOPERATION_BUCKETBY._serialized_end = 1343
+    _WRITEOPERATION_SAVEMODE._serialized_start = 1346
+    _WRITEOPERATION_SAVEMODE._serialized_end = 1483
+    _WRITEOPERATIONV2._serialized_start = 1510
+    _WRITEOPERATIONV2._serialized_end = 2305
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1192
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1250
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2077
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2143
+    _WRITEOPERATIONV2_MODE._serialized_start = 2146
+    _WRITEOPERATIONV2_MODE._serialized_end = 2305
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi 
b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 4bdf1f1ed4e..b8daec597d0 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -259,7 +259,7 @@ class WriteOperation(google.protobuf.message.Message):
     def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
         """(Required) The output of the `input` relation will be persisted 
according to the options."""
     source: builtins.str
-    """(Required) Format value according to the Spark documentation. Examples 
are: text, parquet, delta."""
+    """(Optional) Format value according to the Spark documentation. Examples 
are: text, parquet, delta."""
     path: builtins.str
     table_name: builtins.str
     mode: global___WriteOperation.SaveMode.ValueType
@@ -286,7 +286,7 @@ class WriteOperation(google.protobuf.message.Message):
         self,
         *,
         input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
-        source: builtins.str = ...,
+        source: builtins.str | None = ...,
         path: builtins.str = ...,
         table_name: builtins.str = ...,
         mode: global___WriteOperation.SaveMode.ValueType = ...,
@@ -298,6 +298,8 @@ class WriteOperation(google.protobuf.message.Message):
     def HasField(
         self,
         field_name: typing_extensions.Literal[
+            "_source",
+            b"_source",
             "bucket_by",
             b"bucket_by",
             "input",
@@ -306,6 +308,8 @@ class WriteOperation(google.protobuf.message.Message):
             b"path",
             "save_type",
             b"save_type",
+            "source",
+            b"source",
             "table_name",
             b"table_name",
         ],
@@ -313,6 +317,8 @@ class WriteOperation(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_source",
+            b"_source",
             "bucket_by",
             b"bucket_by",
             "input",
@@ -335,6 +341,11 @@ class WriteOperation(google.protobuf.message.Message):
             b"table_name",
         ],
     ) -> None: ...
+    @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_source", b"_source"]
+    ) -> typing_extensions.Literal["source"] | None: ...
+    @typing.overload
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["save_type", b"save_type"]
     ) -> typing_extensions.Literal["path", "table_name"] | None: ...
diff --git a/python/pyspark/sql/connect/readwriter.py 
b/python/pyspark/sql/connect/readwriter.py
index 8724348592e..ee4e3018bb9 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -472,6 +472,8 @@ class DataFrameWriter(OptionUtils):
     def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> 
None:
         if overwrite is not None:
             self.mode("overwrite" if overwrite else "append")
+        elif self._write.mode is None or self._write.mode != "overwrite":
+            self.mode("append")
         self.saveAsTable(tableName)
 
     insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
@@ -695,9 +697,8 @@ def _test() -> None:
     del pyspark.sql.connect.readwriter.DataFrameWriter.bucketBy.__doc__
     del pyspark.sql.connect.readwriter.DataFrameWriter.sortBy.__doc__
 
-    # TODO(SPARK-41818): Support saveAsTable
+    # TODO(SPARK-42426): insertInto fails when the column names are different 
from the table columns
     del pyspark.sql.connect.readwriter.DataFrameWriter.insertInto.__doc__
-    del pyspark.sql.connect.readwriter.DataFrameWriter.saveAsTable.__doc__
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.readwriter tests")
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b87fb6528bb..19ff342ac3b 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -468,7 +468,7 @@ class DataFrameReader(OptionUtils):
         |  8|
         |  9|
         +---+
-        >>> _ = spark.sql("DROP TABLE tblA")
+        >>> _ = spark.sql("DROP TABLE tblA").collect()
         """
         return self._df(self._jreader.table(tableName))
 
@@ -1234,7 +1234,7 @@ class DataFrameWriter(OptionUtils):
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table")
+        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table").collect()
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon 
Lee")],
         ...     schema=["age", "name"]
@@ -1248,7 +1248,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE bucketed_table")
+        >>> _ = spark.sql("DROP TABLE bucketed_table").collect()
         """
         if not isinstance(numBuckets, int):
             raise TypeError("numBuckets should be an int, got 
{0}.".format(type(numBuckets)))
@@ -1298,7 +1298,7 @@ class DataFrameWriter(OptionUtils):
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a sorted-bucketed 
manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table")
+        ... _ = spark.sql("DROP TABLE IF EXISTS 
sorted_bucketed_table").collect()
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon 
Lee")],
         ...     schema=["age", "name"]
@@ -1313,7 +1313,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table")
+        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table").collect()
         """
         if isinstance(col, (list, tuple)):
             if cols:
@@ -1419,7 +1419,7 @@ class DataFrameWriter(OptionUtils):
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
         >>> df = spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon 
Lee")],
         ...     schema=["age", "name"]
@@ -1440,7 +1440,7 @@ class DataFrameWriter(OptionUtils):
         |140| Haejoon Lee|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA")
+        >>> _ = spark.sql("DROP TABLE tblA").collect()
         """
         if overwrite is not None:
             self.mode("overwrite" if overwrite else "append")
@@ -1497,7 +1497,7 @@ class DataFrameWriter(OptionUtils):
         --------
         Creates a table from a DataFrame, and read it back.
 
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon 
Lee")],
         ...     schema=["age", "name"]
@@ -1510,7 +1510,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA")
+        >>> _ = spark.sql("DROP TABLE tblA").collect()
         """
         self.mode(mode).options(**options)
         if partitionBy is not None:
diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py 
b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
index 0810ce6bb22..db713c3bee5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
@@ -26,11 +26,6 @@ class ReadwriterParityTests(ReadwriterTestsMixin, 
ReusedConnectTestCase):
     def test_bucketed_write(self):
         super().test_bucketed_write()
 
-    # TODO(SPARK-42000): saveAsTable fail to find the default source
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_insert_into(self):
-        super().test_insert_into()
-
     # TODO(SPARK-41834): Implement SparkSession.conf
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_save_and_load(self):


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to