This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 0469479c751 [SPARK-41818][SPARK-42000][CONNECT] Fix saveAsTable to find the default source
0469479c751 is described below
commit 0469479c75174acf9873c819627b74c8286cee6a
Author: Takuya UESHIN <[email protected]>
AuthorDate: Tue Feb 14 09:05:57 2023 +0900
[SPARK-41818][SPARK-42000][CONNECT] Fix saveAsTable to find the default source
### What changes were proposed in this pull request?
Fixes `DataFrameWriter.saveAsTable` to find the default source.
### Why are the changes needed?
Currently `DataFrameWriter.saveAsTable` fails when `format` is not specified: because protobuf defines `source` as a required field, it arrives as an empty string instead of `null`, and `DataFrameWriter` then tries to find the data source `""`.
The `source` field should be optional so that Spark can decide the default source.
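As an illustrative sketch (not part of this patch): once the field is declared `optional` in proto3, the generated Python message tracks field presence explicitly, so an unset `source` can be distinguished from an empty string.

```python
# Hedged sketch: with `optional string source`, field presence is tracked,
# so the server can tell "format never set" apart from an empty string.
from pyspark.sql.connect.proto import commands_pb2

op = commands_pb2.WriteOperation()
print(op.HasField("source"))   # False: no format was specified
op.source = "parquet"
print(op.HasField("source"))   # True: an explicit format was set
```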
### Does this PR introduce _any_ user-facing change?
Yes. Users can now call `saveAsTable` without specifying a `format`; Spark falls back to the default source.
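For example, a minimal sketch of the call path this enables (the Spark Connect endpoint `sc://localhost` is an assumption):

```python
# Hedged sketch of the user-facing behavior: no .format(...) call, so Spark
# falls back to the default data source (spark.sql.sources.default, parquet by default).
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost").getOrCreate()  # assumed endpoint
df = spark.createDataFrame([(100, "Hyukjin Kwon")], schema=["age", "name"])
df.write.saveAsTable("tblA")       # previously failed to find the data source ""
spark.read.table("tblA").show()
```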
### How was this patch tested?
Enabled related tests.
Closes #40000 from ueshin/issues/SPARK-42000/saveAsTable.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
(cherry picked from commit c1f242eb4c514d7fba8e0c47c96e40cba82a39ad)
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/protobuf/spark/connect/commands.proto | 4 +--
.../sql/connect/planner/SparkConnectPlanner.scala | 2 +-
python/pyspark/sql/connect/proto/commands_pb2.py | 32 +++++++++++-----------
python/pyspark/sql/connect/proto/commands_pb2.pyi | 15 ++++++++--
python/pyspark/sql/connect/readwriter.py | 5 ++--
python/pyspark/sql/readwriter.py | 18 ++++++------
.../sql/tests/connect/test_parity_readwriter.py | 5 ----
7 files changed, 44 insertions(+), 37 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 73218697577..8872dc626a9 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -66,8 +66,8 @@ message WriteOperation {
// (Required) The output of the `input` relation will be persisted according to the options.
Relation input = 1;
- // (Required) Format value according to the Spark documentation. Examples are: text, parquet, delta.
- string source = 2;
+ // (Optional) Format value according to the Spark documentation. Examples are: text, parquet, delta.
+ optional string source = 2;
// The destination of the write operation must be either a path or a table.
oneof save_type {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 53d494cdcb7..d509a926cdd 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1539,7 +1539,7 @@ class SparkConnectPlanner(val session: SparkSession) {
w.partitionBy(names.toSeq: _*)
}
- if (writeOperation.getSource != null) {
+ if (writeOperation.hasSource) {
w.format(writeOperation.getSource)
}
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index f7e9260212e..a4b7fe268ce 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
+ b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
)
@@ -151,19 +151,19 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_CREATEDATAFRAMEVIEWCOMMAND._serialized_start = 596
_CREATEDATAFRAMEVIEWCOMMAND._serialized_end = 746
_WRITEOPERATION._serialized_start = 749
- _WRITEOPERATION._serialized_end = 1491
- _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1187
- _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1245
- _WRITEOPERATION_BUCKETBY._serialized_start = 1247
- _WRITEOPERATION_BUCKETBY._serialized_end = 1338
- _WRITEOPERATION_SAVEMODE._serialized_start = 1341
- _WRITEOPERATION_SAVEMODE._serialized_end = 1478
- _WRITEOPERATIONV2._serialized_start = 1494
- _WRITEOPERATIONV2._serialized_end = 2289
- _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1187
- _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1245
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2061
- _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2127
- _WRITEOPERATIONV2_MODE._serialized_start = 2130
- _WRITEOPERATIONV2_MODE._serialized_end = 2289
+ _WRITEOPERATION._serialized_end = 1507
+ _WRITEOPERATION_OPTIONSENTRY._serialized_start = 1192
+ _WRITEOPERATION_OPTIONSENTRY._serialized_end = 1250
+ _WRITEOPERATION_BUCKETBY._serialized_start = 1252
+ _WRITEOPERATION_BUCKETBY._serialized_end = 1343
+ _WRITEOPERATION_SAVEMODE._serialized_start = 1346
+ _WRITEOPERATION_SAVEMODE._serialized_end = 1483
+ _WRITEOPERATIONV2._serialized_start = 1510
+ _WRITEOPERATIONV2._serialized_end = 2305
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1192
+ _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1250
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2077
+ _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2143
+ _WRITEOPERATIONV2_MODE._serialized_start = 2146
+ _WRITEOPERATIONV2_MODE._serialized_end = 2305
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index 4bdf1f1ed4e..b8daec597d0 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -259,7 +259,7 @@ class WriteOperation(google.protobuf.message.Message):
def input(self) -> pyspark.sql.connect.proto.relations_pb2.Relation:
"""(Required) The output of the `input` relation will be persisted according to the options."""
source: builtins.str
- """(Required) Format value according to the Spark documentation. Examples are: text, parquet, delta."""
+ """(Optional) Format value according to the Spark documentation. Examples are: text, parquet, delta."""
path: builtins.str
table_name: builtins.str
mode: global___WriteOperation.SaveMode.ValueType
@@ -286,7 +286,7 @@ class WriteOperation(google.protobuf.message.Message):
self,
*,
input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
- source: builtins.str = ...,
+ source: builtins.str | None = ...,
path: builtins.str = ...,
table_name: builtins.str = ...,
mode: global___WriteOperation.SaveMode.ValueType = ...,
@@ -298,6 +298,8 @@ class WriteOperation(google.protobuf.message.Message):
def HasField(
self,
field_name: typing_extensions.Literal[
+ "_source",
+ b"_source",
"bucket_by",
b"bucket_by",
"input",
@@ -306,6 +308,8 @@ class WriteOperation(google.protobuf.message.Message):
b"path",
"save_type",
b"save_type",
+ "source",
+ b"source",
"table_name",
b"table_name",
],
@@ -313,6 +317,8 @@ class WriteOperation(google.protobuf.message.Message):
def ClearField(
self,
field_name: typing_extensions.Literal[
+ "_source",
+ b"_source",
"bucket_by",
b"bucket_by",
"input",
@@ -335,6 +341,11 @@ class WriteOperation(google.protobuf.message.Message):
b"table_name",
],
) -> None: ...
+ @typing.overload
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_source", b"_source"]
+ ) -> typing_extensions.Literal["source"] | None: ...
+ @typing.overload
def WhichOneof(
self, oneof_group: typing_extensions.Literal["save_type", b"save_type"]
) -> typing_extensions.Literal["path", "table_name"] | None: ...
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 8724348592e..ee4e3018bb9 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -472,6 +472,8 @@ class DataFrameWriter(OptionUtils):
def insertInto(self, tableName: str, overwrite: Optional[bool] = None) -> None:
if overwrite is not None:
self.mode("overwrite" if overwrite else "append")
+ elif self._write.mode is None or self._write.mode != "overwrite":
+ self.mode("append")
self.saveAsTable(tableName)
insertInto.__doc__ = PySparkDataFrameWriter.insertInto.__doc__
@@ -695,9 +697,8 @@ def _test() -> None:
del pyspark.sql.connect.readwriter.DataFrameWriter.bucketBy.__doc__
del pyspark.sql.connect.readwriter.DataFrameWriter.sortBy.__doc__
- # TODO(SPARK-41818): Support saveAsTable
+ # TODO(SPARK-42426): insertInto fails when the column names are different from the table columns
del pyspark.sql.connect.readwriter.DataFrameWriter.insertInto.__doc__
- del pyspark.sql.connect.readwriter.DataFrameWriter.saveAsTable.__doc__
globs["spark"] = (
PySparkSession.builder.appName("sql.connect.readwriter tests")
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index b87fb6528bb..19ff342ac3b 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -468,7 +468,7 @@ class DataFrameReader(OptionUtils):
| 8|
| 9|
+---+
- >>> _ = spark.sql("DROP TABLE tblA")
+ >>> _ = spark.sql("DROP TABLE tblA").collect()
"""
return self._df(self._jreader.table(tableName))
@@ -1234,7 +1234,7 @@ class DataFrameWriter(OptionUtils):
>>> from pyspark.sql.functions import input_file_name
>>> # Write a DataFrame into a Parquet file in a bucketed manner.
- ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table")
+ ... _ = spark.sql("DROP TABLE IF EXISTS bucketed_table").collect()
>>> spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1248,7 +1248,7 @@ class DataFrameWriter(OptionUtils):
|120|Hyukjin Kwon|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE bucketed_table")
+ >>> _ = spark.sql("DROP TABLE bucketed_table").collect()
"""
if not isinstance(numBuckets, int):
raise TypeError("numBuckets should be an int, got {0}.".format(type(numBuckets)))
@@ -1298,7 +1298,7 @@ class DataFrameWriter(OptionUtils):
>>> from pyspark.sql.functions import input_file_name
>>> # Write a DataFrame into a Parquet file in a sorted-bucketed manner.
- ... _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table")
+ >>> _ = spark.sql("DROP TABLE IF EXISTS sorted_bucketed_table").collect()
>>> spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1313,7 +1313,7 @@ class DataFrameWriter(OptionUtils):
|120|Hyukjin Kwon|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE sorted_bucketed_table")
+ >>> _ = spark.sql("DROP TABLE sorted_bucketed_table").collect()
"""
if isinstance(col, (list, tuple)):
if cols:
@@ -1419,7 +1419,7 @@ class DataFrameWriter(OptionUtils):
Examples
--------
- >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
>>> df = spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1440,7 +1440,7 @@ class DataFrameWriter(OptionUtils):
|140| Haejoon Lee|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE tblA")
+ >>> _ = spark.sql("DROP TABLE tblA").collect()
"""
if overwrite is not None:
self.mode("overwrite" if overwrite else "append")
@@ -1497,7 +1497,7 @@ class DataFrameWriter(OptionUtils):
--------
Creates a table from a DataFrame, and read it back.
- >>> _ = spark.sql("DROP TABLE IF EXISTS tblA")
+ >>> _ = spark.sql("DROP TABLE IF EXISTS tblA").collect()
>>> spark.createDataFrame([
... (100, "Hyukjin Kwon"), (120, "Hyukjin Kwon"), (140, "Haejoon Lee")],
... schema=["age", "name"]
@@ -1510,7 +1510,7 @@ class DataFrameWriter(OptionUtils):
|120|Hyukjin Kwon|
|140| Haejoon Lee|
+---+------------+
- >>> _ = spark.sql("DROP TABLE tblA")
+ >>> _ = spark.sql("DROP TABLE tblA").collect()
"""
self.mode(mode).options(**options)
if partitionBy is not None:
diff --git a/python/pyspark/sql/tests/connect/test_parity_readwriter.py b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
index 0810ce6bb22..db713c3bee5 100644
--- a/python/pyspark/sql/tests/connect/test_parity_readwriter.py
+++ b/python/pyspark/sql/tests/connect/test_parity_readwriter.py
@@ -26,11 +26,6 @@ class ReadwriterParityTests(ReadwriterTestsMixin, ReusedConnectTestCase):
def test_bucketed_write(self):
super().test_bucketed_write()
- # TODO(SPARK-42000): saveAsTable fail to find the default source
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_insert_into(self):
- super().test_insert_into()
-
# TODO(SPARK-41834): Implement SparkSession.conf
@unittest.skip("Fails in Spark Connect, should enable.")
def test_save_and_load(self):