This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 0469479c751  [SPARK-41818][SPARK-42000][CONNECT] Fix saveAsTable to find the default source
0469479c751 is described below

commit 0469479c75174acf9873c819627b74c8286cee6a
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Tue Feb 14 09:05:57 2023 +0900

    [SPARK-41818][SPARK-42000][CONNECT] Fix saveAsTable to find the default source

    ### What changes were proposed in this pull request?

    Fixes `DataFrameWriter.saveAsTable` to find the default source.

    ### Why are the changes needed?

    Currently `DataFrameWriter.saveAsTable` fails when `format` is not specified: protobuf defines `source` as required, so it arrives as an empty string instead of `null`, and `DataFrameWriter` then tries to find the data source `""`. The `source` field should be optional so that Spark can decide the default source.

    ### Does this PR introduce _any_ user-facing change?

    Yes. Users can call `saveAsTable` without specifying a `format`.

    ### How was this patch tested?

    Enabled the related tests.

    Closes #40000 from ueshin/issues/SPARK-42000/saveAsTable.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
    (cherry picked from commit c1f242eb4c514d7fba8e0c47c96e40cba82a39ad)
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
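
For illustration, a minimal Python sketch of the call this change enables through Spark Connect; the endpoint address and table name are example values, and with no format() call the server falls back to the default source configured by spark.sql.sources.default (Parquet unless overridden):

    from pyspark.sql import SparkSession

    # Example Spark Connect endpoint; adjust to your environment.
    spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

    df = spark.range(3)

    # No .format(...) here: the WriteOperation's `source` field is left unset,
    # so the server resolves the default data source instead of looking up "".
    df.write.mode("overwrite").saveAsTable("example_table")

    spark.read.table("example_table").show()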
---
 .../src/main/protobuf/spark/connect/commands.proto |  4 +--
 .../sql/connect/planner/SparkConnectPlanner.scala  |  2 +-
 python/pyspark/sql/connect/proto/commands_pb2.py   | 32 +++++++++++-----------
 python/pyspark/sql/connect/proto/commands_pb2.pyi  | 15 ++++++++--
 python/pyspark/sql/connect/readwriter.py           |  5 ++--
 python/pyspark/sql/readwriter.py                   | 18 ++++++------
 .../sql/tests/connect/test_parity_readwriter.py    |  5 ----
 7 files changed, 44 insertions(+), 37 deletions(-)

diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 73218697577..8872dc626a9 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -66,8 +66,8 @@ message WriteOperation {
   // (Required) The output of the `input` relation will be persisted according to the options.
   Relation input = 1;
 
-  // (Required) Format value according to the Spark documentation. Examples are: text, parquet, delta.
-  string source = 2;
+  // (Optional) Format value according to the Spark documentation. Examples are: text, parquet, delta.
+  optional string source = 2;
 
   // The destination of the write operation must be either a path or a table.
   oneof save_type {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 53d494cdcb7..d509a926cdd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1539,7 +1539,7 @@ class SparkConnectPlanner(val session: SparkSession) {
       w.partitionBy(names.toSeq: _*)
     }
 
-    if (writeOperation.getSource != null) {
+    if (writeOperation.hasSource) {
       w.format(writeOperation.getSource)
     }
 
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index f7e9260212e..a4b7fe268ce 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
 )
 
 
@@ -151,19 +151,19 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
     _CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
     _WRITEOPERATION._serialized_start = 749
-    _WRITEOPERATION._serialized_end = 1491
-    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
-    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
-    _WRITEOPERATION_BUCKETBY._serialized_start = 1247
-    _WRITEOPERATION_BUCKETBY._serialized_end = 1338
-    _WRITEOPERATION_SAVEMODE._serialized_start = 1341
-    _WRITEOPERATION_SAVEMODE._serialized_end = 1478
-    _WRITEOPERATIONV2._serialized_start = 1494
-    _WRITEOPERATIONV2._serialized_end = 2289
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
-    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
-    _WRITEOPERATIONV2_MODE._serialized_start = 2130
-    _WRITEOPERATIONV2_MODE._serialized_end = 2289
+    _WRITEOPERATION._serialized_end = 1507
+    _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1192
+    _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1250
+    _WRITEOPERATION_BUCKETBY._serialized_start = 1252
+    _WRITEOPERATION_BUCKETBY._serialized_end = 1343
+    _WRITEOPERATION_SAVEMODE._serialized_start = 1346
+    _WRITEOPERATION_SAVEMODE._serialized_end = 1483
+    _WRITEOPERATIONV2._serialized_start = 1510
+    _WRITEOPERATIONV2._serialized_end = 2305
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1192
+    _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1250
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2077
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2143
+    _WRITEOPERATIONV2_MODE._serialized_start = 2146
+    _WRITEOPERATIONV2_MODE._serialized_end = 2305
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 4bdf1f1ed4e..b8daec597d0 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -259,7 +259,7 @@ class WriteOperation(google.protobuf.message.Message):
     def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
         """(Required) The output of the `input` relation will be persisted according to the options."""
     source: builtins.str
-    """(Required) Format value according to the Spark documentation. Examples are: text, parquet, delta."""
+    """(Optional) Format value according to the Spark documentation. Examples are: text, parquet, delta."""
     path: builtins.str
     table_name: builtins.str
     mode: global___WriteOperation.SaveMode.ValueType
@@ -286,7 +286,7 @@ class WriteOperation(google.protobuf.message.Message):
         self,
         *,
         input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
-        source: builtins.str = ...,
+        source: builtins.str | None = ...,
         path: builtins.str = ...,
         table_name: builtins.str = ...,
         mode: global___WriteOperation.SaveMode.ValueType = ...,
@@ -298,6 +298,8 @@ class WriteOperation(google.protobuf.message.Message):
     def HasField(
         self,
         field_name: typing_extensions.Literal[
+            "_source",
+            b"_source",
             "bucket_by",
             b"bucket_by",
             "input",
@@ -306,6 +308,8 @@ class WriteOperation(google.protobuf.message.Message):
             b"path",
             "save_type",
             b"save_type",
+            "source",
+            b"source",
             "table_name",
             b"table_name",
         ],
@@ -313,6 +317,8 @@ class WriteOperation(google.protobuf.message.Message):
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_source",
+            b"_source",
             "bucket_by",
             b"bucket_by",
             "input",
@@ -335,6 +341,11 @@ class WriteOperation(google.protobuf.message.Message):
             b"table_name",
         ],
     ) -> None: ...
+    @typing.overload
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_source", b"_source"]
+    ) -> typing_extensions.Literal["source"] | None: ...
+    @typing.overload
     def WhichOneof(
         self, oneof_group: typing_extensions.Literal["save_type", b"save_type"]
     ) -> typing_extensions.Literal["path", "table_name"] | None: ...
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 8724348592e..ee4e3018bb9 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -472,6 +472,8 @@ class DataFrameWriter(OptionUtils):
     def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None:
         if overwrite is not None:
             self.mode("overwrite" if overwrite else "append")
+        elif self._write.mode is None or self._write.mode != "overwrite":
+            self.mode("append")
         self.saveAsTable(tableName)
 
     insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
@@ -695,9 +697,8 @@ def _test() -> None:
     del pyspark.sql.connect.readwriter.DataFrameWriter.bucketBy.__doc__
     del pyspark.sql.connect.readwriter.DataFrameWriter.sortBy.__doc__
 
-    # TODO(SPARK-41818): Support saveAsTable
+    # TODO(SPARK-42426): insertInto fails when the column names are different from the table columns
     del pyspark.sql.connect.readwriter.DataFrameWriter.insertInto.__doc__
-    del pyspark.sql.connect.readwriter.DataFrameWriter.saveAsTable.__doc__
 
     globs["spark"] = (
         PySparkSession.builder.appName("sql.connect.readwriter tests")
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b87fb6528bb..19ff342ac3b 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -468,7 +468,7 @@ class DataFrameReader(OptionUtils):
         | 8|
         | 9|
         +---+
-        >>> _ = spark.sql("DROP TABLE tblA")
+        >>> _ = spark.sql("DROP TABLE tblA").collect()
         """
         return self._df(self._jreader.table(tableName))
 
@@ -1234,7 +1234,7 @@ class DataFrameWriter(OptionUtils):
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table")
+        ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table").collect()
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1248,7 +1248,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE bucketed_table")
+        >>> _ = spark.sql("DROP TABLE bucketed_table").collect()
         """
         if not isinstance(numBuckets, int):
             raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))
@@ -1298,7 +1298,7 @@ class DataFrameWriter(OptionUtils):
 
         >>> from pyspark.sql.functions import input_file_name
         >>> # Write a DataFrame into a Parquet file in a sorted-bucketed manner.
-        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table")
+        ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table").collect()
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1313,7 +1313,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table")
+        >>> _ = spark.sql("DROP TABLE sorted_bucketed_table").collect()
         """
         if isinstance(col, (list, tuple)):
             if cols:
@@ -1419,7 +1419,7 @@ class DataFrameWriter(OptionUtils):
 
         Examples
         --------
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
         >>> df = spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1440,7 +1440,7 @@ class DataFrameWriter(OptionUtils):
         |140| Haejoon Lee|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA")
+        >>> _ = spark.sql("DROP TABLE tblA").collect()
         """
         if overwrite is not None:
             self.mode("overwrite" if overwrite else "append")
@@ -1497,7 +1497,7 @@ class DataFrameWriter(OptionUtils):
         --------
        Creates a table from a DataFrame, and read it back.
 
-        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
+        >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
         >>> spark.createDataFrame([
         ...     (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
         ...     schema=["age", "name"]
@@ -1510,7 +1510,7 @@ class DataFrameWriter(OptionUtils):
         |120|Hyukjin Kwon|
         |140| Haejoon Lee|
         +---+------------+
-        >>> _ = spark.sql("DROP TABLE tblA")
+        >>> _ = spark.sql("DROP TABLE tblA").collect()
         """
         self.mode(mode).options(**options)
         if partitionBy is not None:
diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
index 0810ce6bb22..db713c3bee5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
@@ -26,11 +26,6 @@ class ReadwriterParityTests(ReadwriterTestsMixin, ReusedConnectTestCase):
     def test_bucketed_write(self):
         super().test_bucketed_write()
 
-    # TODO(SPARK-42000): saveAsTable fail to find the default source
-    @unittest.skip("Fails in Spark Connect, should enable.")
-    def test_insert_into(self):
-        super().test_insert_into()
-
     # TODO(SPARK-41834): Implement SparkSession.conf
     @unittest.skip("Fails in Spark Connect, should enable.")
     def test_save_and_load(self):

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org