This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new dbe23c8e88d  [SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source
dbe23c8e88d is described below

commit dbe23c8e88d1a2968ae1c17ec9ee3029ef7a348a
Author: Takuya UESHIN <ues...@databricks.com>
AuthorDate: Wed Feb 22 16:53:06 2023 -0400

    [SPARK-42522][CONNECT] Fix DataFrameWriterV2 to find the default source

    ### What changes were proposed in this pull request?

    Fixes `DataFrameWriterV2` to find the default source.

    ### Why are the changes needed?

    Currently `DataFrameWriterV2` in Spark Connect doesn't work without a provider and fails with a confusing error. For example:

    ```py
    df.writeTo("test_table").create()
    ```

    ```
    pyspark.errors.exceptions.connect.SparkConnectGrpcException: (org.apache.spark.SparkClassNotFoundException) [DATA_SOURCE_NOT_FOUND] Failed to find the data source: . Please find packages at `https://spark.apache.org/third-party-projects.html`.
    ```

    ### Does this PR introduce _any_ user-facing change?

    Users will be able to use `DataFrameWriterV2` without a provider, the same as in PySpark.

    ### How was this patch tested?

    Added some tests.

    Closes #40109 from ueshin/issues/SPARK-42522/writer_v2.

    Authored-by: Takuya UESHIN <ues...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../src/main/protobuf/spark/connect/commands.proto      |  2 +-
 .../spark/sql/connect/planner/SparkConnectPlanner.scala |  6 +++---
 python/pyspark/sql/connect/proto/commands_pb2.py        | 12 ++++++------
 python/pyspark/sql/connect/proto/commands_pb2.pyi       | 16 ++++++++++++++--
 python/pyspark/sql/tests/test_readwriter.py             | 12 ++++++++++++
 5 files changed, 36 insertions(+), 12 deletions(-)
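For reference, a rough usage sketch of the behavior described above (not part of the patch; it assumes a Spark Connect `SparkSession` bound to `spark`, and the table names are placeholders):

```py
df = spark.range(100)

# Explicit provider: unchanged by this patch.
df.writeTo("tbl_with_provider").using("parquet").create()

# No provider: before this patch the planner saw an empty provider string and
# failed with the DATA_SOURCE_NOT_FOUND error quoted above; with the patch the
# call behaves like classic PySpark (in a non-Hive test session the new test
# expects an AnalysisException about Hive support instead).
try:
    df.writeTo("tbl_without_provider").create()
except Exception as e:
    print(e)
```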
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
index 7567b0e3d7c..1f2f473a050 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/commands.proto
@@ -128,7 +128,7 @@ message WriteOperationV2 {
   // (Optional) A provider for the underlying output data source. Spark's default catalog supports
   // "parquet", "json", etc.
-  string provider = 3;
+  optional string provider = 3;

   // (Optional) List of columns for partitioning for output table created by `create`,
   // `createOrReplace`, or `replace`
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index a14d3632d28..268bf02fad9 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1614,7 +1614,7 @@ class SparkConnectPlanner(val session: SparkSession) {
     writeOperation.getMode match {
       case proto.WriteOperationV2.Mode.MODE_CREATE =>
-        if (writeOperation.getProvider != null) {
+        if (writeOperation.hasProvider) {
           w.using(writeOperation.getProvider).create()
         } else {
           w.create()
@@ -1626,13 +1626,13 @@ class SparkConnectPlanner(val session: SparkSession) {
       case proto.WriteOperationV2.Mode.MODE_APPEND =>
        w.append()
      case proto.WriteOperationV2.Mode.MODE_REPLACE =>
-        if (writeOperation.getProvider != null) {
+        if (writeOperation.hasProvider) {
          w.using(writeOperation.getProvider).replace()
        } else {
          w.replace()
        }
      case proto.WriteOperationV2.Mode.MODE_CREATE_OR_REPLACE =>
-        if (writeOperation.getProvider != null) {
+        if (writeOperation.hasProvider) {
          w.using(writeOperation.getProvider).createOrReplace()
        } else {
          w.createOrReplace()
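Why `optional` matters here: a plain proto3 `string` field has no presence tracking, so the generated accessor returns an empty string for an unset field and never `null`; the old `getProvider != null` branch was therefore always taken and the planner effectively called `using("")`, which matches the empty data source name in the error quoted above. Marking the field `optional` gives it explicit presence, which is what the new `hasProvider` check consults. A small Python sketch of the same presence semantics against the regenerated message (illustrative only, assuming a pyspark build that includes this change):

```py
from pyspark.sql.connect.proto import commands_pb2

op = commands_pb2.WriteOperationV2(table_name="test_table")

print(op.provider)               # "" - reading an unset string still yields the default
print(op.HasField("provider"))   # False - presence is now tracked explicitly

op.provider = "parquet"
print(op.HasField("provider"))   # True - only now should the planner call using(...)
```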
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.py b/python/pyspark/sql/connect/proto/commands_pb2.py
index faa7dd65e2e..c8ade1ea81b 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.py
+++ b/python/pyspark/sql/connect/proto/commands_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__

 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
+    b'\n\x1cspark/connect/commands.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"\xab\x03\n\x07\x43ommand\x12]\n\x11register_function\x18\x01 \x01(\x0b\x32..spark.connect.CommonInlineUserDefinedFunctionH\x00R\x10registerFunction\x12H\n\x0fwrite_operation\x18\x02 \x01(\x0b\x32\x1d.spark.connect.WriteOperationH\x00R\x0ewriteOperation\x12_\n\x15\x63reate_dataframe_view\x1 [...]
 )
@@ -177,11 +177,11 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _WRITEOPERATION_SAVEMODE._serialized_start = 1639
     _WRITEOPERATION_SAVEMODE._serialized_end = 1776
     _WRITEOPERATIONV2._serialized_start = 1803
-    _WRITEOPERATIONV2._serialized_end = 2598
+    _WRITEOPERATIONV2._serialized_end = 2616
     _WRITEOPERATIONV2_OPTIONSENTRY._serialized_start = 1224
     _WRITEOPERATIONV2_OPTIONSENTRY._serialized_end = 1282
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2370
-    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2436
-    _WRITEOPERATIONV2_MODE._serialized_start = 2439
-    _WRITEOPERATIONV2_MODE._serialized_end = 2598
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_start = 2375
+    _WRITEOPERATIONV2_TABLEPROPERTIESENTRY._serialized_end = 2441
+    _WRITEOPERATIONV2_MODE._serialized_start = 2444
+    _WRITEOPERATIONV2_MODE._serialized_end = 2603
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/commands_pb2.pyi b/python/pyspark/sql/connect/proto/commands_pb2.pyi
index c102624ca44..fb767ead329 100644
--- a/python/pyspark/sql/connect/proto/commands_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/commands_pb2.pyi
@@ -506,7 +506,7 @@ class WriteOperationV2(google.protobuf.message.Message):
         *,
         input: pyspark.sql.connect.proto.relations_pb2.Relation | None = ...,
         table_name: builtins.str = ...,
-        provider: builtins.str = ...,
+        provider: builtins.str | None = ...,
         partitioning_columns: collections.abc.Iterable[
             pyspark.sql.connect.proto.expressions_pb2.Expression
         ]
@@ -519,12 +519,21 @@ class WriteOperationV2(google.protobuf.message.Message):
     def HasField(
         self,
         field_name: typing_extensions.Literal[
-            "input", b"input", "overwrite_condition", b"overwrite_condition"
+            "_provider",
+            b"_provider",
+            "input",
+            b"input",
+            "overwrite_condition",
+            b"overwrite_condition",
+            "provider",
+            b"provider",
         ],
     ) -> builtins.bool: ...
     def ClearField(
         self,
         field_name: typing_extensions.Literal[
+            "_provider",
+            b"_provider",
             "input",
             b"input",
             "mode",
@@ -543,5 +552,8 @@ class WriteOperationV2(google.protobuf.message.Message):
             b"table_properties",
         ],
     ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_provider", b"_provider"]
+    ) -> typing_extensions.Literal["provider"] | None: ...

 global___WriteOperationV2 = WriteOperationV2
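In the regenerated stubs the optional field is backed by the synthetic `_provider` oneof, so type checkers now accept `provider=None` in the constructor and the presence helpers listed above. A short illustrative sketch, again assuming a pyspark build with this change:

```py
from pyspark.sql.connect.proto import commands_pb2

op = commands_pb2.WriteOperationV2(provider="parquet")
print(op.WhichOneof("_provider"))   # "provider" - the field is set

op.ClearField("provider")
print(op.WhichOneof("_provider"))   # None - cleared back to "absent", not ""
```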
diff --git a/python/pyspark/sql/tests/test_readwriter.py b/python/pyspark/sql/tests/test_readwriter.py
index 9cd3e613667..7f9b5e61051 100644
--- a/python/pyspark/sql/tests/test_readwriter.py
+++ b/python/pyspark/sql/tests/test_readwriter.py
@@ -19,6 +19,7 @@ import os
 import shutil
 import tempfile
+from pyspark.errors import AnalysisException
 from pyspark.sql.functions import col
 from pyspark.sql.readwriter import DataFrameWriterV2
 from pyspark.sql.types import StructType, StructField, StringType
@@ -215,6 +216,17 @@ class ReadwriterV2TestsMixin:
         self.assertIsInstance(writer.partitionedBy(bucket(11, col("id"))), tpe)
         self.assertIsInstance(writer.partitionedBy(bucket(3, "id"), hours(col("ts"))), tpe)

+    def test_create(self):
+        df = self.df
+        with self.table("test_table"):
+            df.writeTo("test_table").using("parquet").create()
+            self.assertEqual(100, self.spark.sql("select * from test_table").count())
+
+    def test_create_without_provider(self):
+        df = self.df
+        with self.assertRaisesRegex(AnalysisException, "Hive support is required"):
+            df.writeTo("test_table").create()
+

 class ReadwriterTests(ReadwriterTestsMixin, ReusedSQLTestCase):
     pass

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org